From 760dd9ad11756bda554c243b8468afd4a42d5b10 Mon Sep 17 00:00:00 2001 From: zwl <1633720889@qq.com> Date: Wed, 6 Mar 2024 23:15:22 +0800 Subject: [PATCH 01/12] =?UTF-8?q?=E6=B7=BB=E5=8A=A0memory=E5=AE=9E?= =?UTF-8?q?=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- e2e/base_test.go | 9 +- e2e/memory_test.go | 30 ++ memory/consumer.go | 155 ++++++++++ memory/consumergroup.go | 264 ++++++++++++++++++ .../equaldivide/balancer.go | 52 ++++ memory/mq.go | 139 +++++++++ memory/partition.go | 40 +++ memory/producer.go | 51 ++++ memory/produceridgetter/hash/get.go | 19 ++ memory/produceridgetter/hash/get_test.go | 1 + memory/topic.go | 86 ++++++ memory/type.go | 14 + 12 files changed, 855 insertions(+), 5 deletions(-) create mode 100644 e2e/memory_test.go create mode 100644 memory/consumer.go create mode 100644 memory/consumergroup.go create mode 100644 memory/consumerpartitionassigner/equaldivide/balancer.go create mode 100644 memory/mq.go create mode 100644 memory/partition.go create mode 100644 memory/producer.go create mode 100644 memory/produceridgetter/hash/get.go create mode 100644 memory/produceridgetter/hash/get_test.go create mode 100644 memory/topic.go create mode 100644 memory/type.go diff --git a/e2e/base_test.go b/e2e/base_test.go index f7ff251..7bf69f7 100644 --- a/e2e/base_test.go +++ b/e2e/base_test.go @@ -18,17 +18,17 @@ package e2e import ( "context" + "github.com/ecodeclub/mq-api/mqerr" + "github.com/stretchr/testify/assert" + "golang.org/x/sync/errgroup" "log" "sync" "testing" "time" "github.com/ecodeclub/mq-api" - "github.com/ecodeclub/mq-api/mqerr" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" - "golang.org/x/sync/errgroup" ) type MQCreator interface { @@ -309,7 +309,6 @@ func (b *TestSuite) TestMQ_Close() { _, err = p.ProduceWithPartition(context.Background(), &mq.Message{}, partitions-1) require.ErrorIs(t, err, mqerr.ErrProducerIsClosed) - // 调用consumer上的方法会返回ErrConsumerIsClosed _, err = c.ConsumeChan(context.Background()) require.ErrorIs(t, err, mqerr.ErrConsumerIsClosed) @@ -417,7 +416,7 @@ func (b *TestSuite) TestProducer_ProduceWithPartition() { producers, _ := b.newProducersAndConsumers(t, topic14, partitions, producerInfo{Num: 1}, consumerInfo{}) ctx, cancelFunc := context.WithCancel(context.Background()) - cancelFunc() + defer cancelFunc() p := producers[0] diff --git a/e2e/memory_test.go b/e2e/memory_test.go new file mode 100644 index 0000000..2af0296 --- /dev/null +++ b/e2e/memory_test.go @@ -0,0 +1,30 @@ +//go:build e2e + +package e2e + +import ( + "context" + "testing" + + "github.com/ecodeclub/mq-api" + "github.com/ecodeclub/mq-api/memory" + "github.com/stretchr/testify/suite" +) + +func TestMemory(t *testing.T) { + suite.Run(t, NewTestSuite( + &MemoryTestSuite{}, + )) +} + +type MemoryTestSuite struct{} + +func (k *MemoryTestSuite) Create() mq.MQ { + memoryMq := memory.NewMQ() + + return memoryMq +} + +func (k *MemoryTestSuite) Ping(ctx context.Context) error { + return nil +} diff --git a/memory/consumer.go b/memory/consumer.go new file mode 100644 index 0000000..c07b289 --- /dev/null +++ b/memory/consumer.go @@ -0,0 +1,155 @@ +package memory + +import ( + "context" + "errors" + "log" + "sync" + "time" + + "github.com/ecodeclub/mq-api" + "github.com/ecodeclub/mq-api/mqerr" +) + +const ( + interval = 1 * time.Second + defaultMessageChannelSize = 1000 + // 每个分区取数据的上限 + limit = 25 +) + +var ErrConsumerClose = errors.New("消费者已关闭") + +type Consumer struct { + locker sync.RWMutex + name string + closed bool + // 用于存放分区号,每个元素就是一个分区号 + partitions []*Partition + partitionRecords []PartitionRecord + closeCh chan struct{} + msgCh chan *mq.Message + once sync.Once + reportCh chan *Event + receiveCh chan *Event +} + +func (c *Consumer) Consume(ctx context.Context) (*mq.Message, error) { + if c.isClosed() { + return nil, mqerr.ErrConsumerIsClosed + } + select { + case val, ok := <-c.msgCh: + if !ok { + return nil, mqerr.ErrConsumerIsClosed + } + return val, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// 启动Consume +func (c *Consumer) Run() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + log.Printf("消费者 %s 开始消费数据", c.name) + for idx, record := range c.partitionRecords { + msgs := c.partitions[record.Index].consumerMsg(record.Cursor, limit) + for _, msg := range msgs { + log.Printf("消费者 %s 消费数据 %v", c.name, msg) + c.msgCh <- msg + } + record.Cursor += len(msgs) + errCh := make(chan error, 1) + c.reportCh <- &Event{ + Type: ReportOffset, + Data: ReportData{ + Records: []PartitionRecord{record}, + ErrChan: errCh, + }, + } + log.Printf("获取是否消费成功", c.name) + err := <-errCh + if err != nil { + log.Printf("上报偏移量失败:%v", err) + break + } + close(errCh) + c.partitionRecords[idx] = record + } + log.Printf("消费者 %s 结束消费数据", c.name) + case event, ok := <-c.receiveCh: + log.Println(ok, "xxxxxxxxxxxxooooooooooo", c.name) + if !ok { + return + } + log.Println(event.Type, "xxxxxxxxxxxx", c.name) + // 处理各种事件 + c.Handle(event) + } + } +} + +func (c *Consumer) Handle(event *Event) { + switch event.Type { + // 服务端发起的重平衡事件 + case Rejoin: + // 消费者上报消费进度 + log.Printf("消费者 %s开始上报消费进度", c.name) + c.reportCh <- &Event{ + Type: RejoinAck, + Data: c.partitionRecords, + } + // 设置消费进度 + partitionInfo := <-c.receiveCh + log.Printf("消费者 %s接收到分区信息 %v", c.name, partitionInfo) + c.partitionRecords = partitionInfo.Data.([]PartitionRecord) + // 返回设置完成的信号 + c.reportCh <- &Event{ + Type: PartitionNotifyAck, + } + case Close: + c.Close() + } +} + +func (c *Consumer) ConsumeChan(ctx context.Context) (<-chan *mq.Message, error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } + if c.isClosed() { + return nil, mqerr.ErrConsumerIsClosed + } + return c.msgCh, nil +} + +func (c *Consumer) Close() error { + c.locker.Lock() + defer c.locker.Unlock() + c.once.Do(func() { + c.closed = true + c.reportCh <- &Event{ + Type: ExitGroup, + Data: c.closeCh, + } + log.Printf("消费者 %s 准备关闭", c.name) + // 等待服务端退出完成 + <-c.closeCh + // 关闭资源 + close(c.receiveCh) + close(c.msgCh) + log.Printf("消费者 %s 关闭成功", c.name) + }) + + return nil +} + +func (c *Consumer) isClosed() bool { + c.locker.RLock() + defer c.locker.RUnlock() + return c.closed +} diff --git a/memory/consumergroup.go b/memory/consumergroup.go new file mode 100644 index 0000000..93af58c --- /dev/null +++ b/memory/consumergroup.go @@ -0,0 +1,264 @@ +package memory + +import ( + "fmt" + "github.com/ecodeclub/ekit/syncx" + "github.com/ecodeclub/mq-api" + "github.com/pkg/errors" + "log" + "sync" + "sync/atomic" + "time" +) + +var ErrReportOffsetFail = errors.New("非平衡状态,无法上报偏移量") + +const ( + // ExitGroup 退出信号 + ExitGroup = "exit_group" + // ExitGroupAck 退出的确认信号 + ExitGroupAck = "exit_group_ack" + // ReportOffset 上报偏移量信号 + ReportOffset = "report_offset" + // ReportOffsetAck 上报偏移量确认信号 + ReportOffsetAck = "report_offset_ack" + // Rejoin 通知consumer重新加入消费组 + Rejoin = "rejoin" + // RejoinAck 表示客户端收到重新加入消费组的指令并将offset进行上报 + RejoinAck = "rejoin_ack" + Close = "close" + // PartitionNotify 下发分区情况事件 + PartitionNotify = "partition_notify" + // PartitionNotifyAck 下发分区情况确认 + PartitionNotifyAck = "partition_notify_ack" + + StatusStable = 1 // 稳定状态,可以正常的进行消费数据 + StatusBalancing = 2 +) + +// ConsumerGroup 表示消费组是并发安全的 +type ConsumerGroup struct { + name string + // 存储消费者元数据,键为消费者的名称 + consumers syncx.Map[string, *ConsumerMetaData] + // 消费者平衡器 + consumerPartitionBalancer ConsumerPartitionAssigner + // 分区消费记录 + partitionRecords syncx.Map[int, PartitionRecord] + // 分区 + partitions []*Partition + once sync.Once + status int32 + // 用于接受在重平衡阶段channel的返回数据 + balanceCh chan struct{} +} + +type PartitionRecord struct { + // 属于哪个分区 + Index int + // 消费进度 + Cursor int +} +type ReportData struct { + Records []PartitionRecord + ErrChan chan error +} + +type ConsumerMetaData struct { + // 用于消费者上报数据,如退出或加入消费组,上报消费者消费分区的偏移量 + reportCh chan *Event + // 用于消费组给消费者下发数据,如下发开始重平衡开始通知,告知消费者可以消费的channel + receiveCh chan *Event + // 消费者的名字 + name string +} + +type Event struct { + // 事件类型 + Type string + // 事件所需要处理的数据 + Data any +} + +func (c *ConsumerGroup) Handler(name string, event *Event) { + switch event.Type { + case ExitGroup: + closeCh := event.Data.(chan struct{}) + c.ExitGroup(name, closeCh) + case ReportOffset: + data := event.Data.(ReportData) + var err error + err = c.ReportOffset(data.Records) + data.ErrChan <- err + log.Printf("消费者%s上报offset成功", name) + case RejoinAck: + // consumer响应重平衡信号返回的数据,返回的是当前所有分区的偏移量 + records := event.Data.([]PartitionRecord) + // 不管上报成不成功 + _ = c.ReportOffset(records) + log.Printf("消费者%s成功接受到重平衡信号,并上报offset", name) + c.balanceCh <- struct{}{} + case PartitionNotifyAck: + log.Printf("消费者%s 成功设置分区信息", name) + c.balanceCh <- struct{}{} + } +} + +// ExitGroup 退出消费组 +func (c *ConsumerGroup) ExitGroup(name string, closeCh chan struct{}) { + // 把自己从消费组内摘除 + for { + if atomic.CompareAndSwapInt32(&c.status, StatusStable, StatusBalancing) { + defer atomic.CompareAndSwapInt32(&c.status, StatusBalancing, StatusStable) + log.Printf("消费者 %s 准备退出消费组", name) + c.consumers.Delete(name) + c.reBalance() + log.Printf("给消费者 %s 发送退出确认信号", name) + close(closeCh) + log.Printf("消费者 %s 成功退出消费组", name) + return + } + } + +} + +// ReportOffset 上报偏移量 +func (c *ConsumerGroup) ReportOffset(records []PartitionRecord) error { + if atomic.LoadInt32(&c.status) != StatusStable { + return ErrReportOffsetFail + } + for _, record := range records { + c.partitionRecords.Store(record.Index, record) + } + return nil + +} + +func (c *ConsumerGroup) Close() { + c.consumers.Range(func(key string, value *ConsumerMetaData) bool { + value.receiveCh <- &Event{ + Type: Close, + } + return true + }) + // 等待一秒退出完成 + time.Sleep(1 * time.Second) +} + +// reBalance 单独使用该方法是并发不安全的 +func (c *ConsumerGroup) reBalance() { + log.Println("开始重平衡") + // 通知每一个消费者进行偏移量的上报 + length := 0 + consumers := make([]string, 0, 20) + log.Println("开始给每个消费者,重平衡信号") + c.consumers.Range(func(key string, value *ConsumerMetaData) bool { + log.Printf("开始通知消费者%s", key) + value.receiveCh <- &Event{ + Type: Rejoin, + } + consumers = append(consumers, key) + length++ + log.Printf("通知消费者%s成功", key) + return true + }) + number := 0 + // 等待所有消费者都接收到信号,并上报自己offset + for length > 0 { + select { + case <-c.balanceCh: + number++ + if number == length { + // 接收到所有信号 + log.Println("所有消费者已经接受到重平衡请求,并上报了消费进度") + consumerMap := c.consumerPartitionBalancer.AssignPartition(consumers, len(c.partitions)) + // 通知所有消费者分配 + log.Println("开始分配分区") + for consumerName, partitions := range consumerMap { + // 查找消费者所属的channel + log.Printf("消费者 %s 消费 %v 分区", consumerName, partitions) + consumer, ok := c.consumers.Load(consumerName) + if ok { + // 往每个消费者的receive_channel发送partition的信息 + records := make([]PartitionRecord, 0, len(partitions)) + for _, p := range partitions { + record, ok := c.partitionRecords.Load(p) + if ok { + records = append(records, record) + } + } + consumer.receiveCh <- &Event{ + Type: PartitionNotify, + Data: records, + } + // 等待消费者接收到并保存 + <-c.balanceCh + + } + } + log.Println("重平衡结束") + return + } + } + } + log.Println("重平衡结束") + +} + +// JoinGroup 加入消费组 +func (c *ConsumerGroup) JoinGroup() *Consumer { + for { + if atomic.CompareAndSwapInt32(&c.status, StatusStable, StatusBalancing) { + var length int + c.consumers.Range(func(key string, value *ConsumerMetaData) bool { + length++ + return true + }) + name := fmt.Sprintf("%s_%d", c.name, length) + reportCh := make(chan *Event, 16) + receiveCh := make(chan *Event, 16) + consumer := &Consumer{ + partitions: c.partitions, + receiveCh: receiveCh, + reportCh: reportCh, + name: name, + msgCh: make(chan *mq.Message, 1000), + partitionRecords: make([]PartitionRecord, 0), + closeCh: make(chan struct{}), + } + c.consumers.Store(name, &ConsumerMetaData{ + reportCh: reportCh, + receiveCh: receiveCh, + name: name, + }) + go c.HandleConsumerSignals(name, reportCh) + go consumer.Run() + log.Println(fmt.Sprintf("新建消费者 %s", name)) + // 重平衡分配分区 + c.reBalance() + atomic.CompareAndSwapInt32(&c.status, StatusBalancing, StatusStable) + return consumer + } + } +} + +// HandleConsumerSignals 处理消费者上报的事件 +func (c *ConsumerGroup) HandleConsumerSignals(name string, reportCh chan *Event) { + for { + select { + case event := <-reportCh: + c.Handler(name, event) + if event.Type == ExitGroup { + close(reportCh) + return + } + } + } +} + +func min(i, j int) int { + if i < j { + return i + } + return j +} diff --git a/memory/consumerpartitionassigner/equaldivide/balancer.go b/memory/consumerpartitionassigner/equaldivide/balancer.go new file mode 100644 index 0000000..ee956dc --- /dev/null +++ b/memory/consumerpartitionassigner/equaldivide/balancer.go @@ -0,0 +1,52 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package equaldivide + +type Balancer struct{} + +func (b *Balancer) AssignPartition(consumers []string, partitions int) map[string][]int { + result := make(map[string][]int) + consumerCount := len(consumers) + partitionPerConsumer := partitions / consumerCount + remainingPartitions := partitions % consumerCount + + // 初始化每个 consumer 对应的 partitions + for _, consumer := range consumers { + result[consumer] = make([]int, 0) + } + + // 平均分配 partitions + partitionIndex := 0 + for i := 0; i < consumerCount; i++ { + consumer := consumers[i] + numPartitions := partitionPerConsumer + // 如果还有剩余的 partitions,则将其分配给当前 consumer + if remainingPartitions > 0 { + numPartitions++ + remainingPartitions-- + } + // 分配 partitions + for j := 0; j < numPartitions; j++ { + result[consumer] = append(result[consumer], partitionIndex) + partitionIndex++ + } + } + + return result +} + +func NewBalancer() *Balancer { + return &Balancer{} +} diff --git a/memory/mq.go b/memory/mq.go new file mode 100644 index 0000000..72bf49a --- /dev/null +++ b/memory/mq.go @@ -0,0 +1,139 @@ +package memory + +import ( + "context" + "fmt" + "github.com/ecodeclub/mq-api/internal/pkg/validator" + "log" + "sync" + + "github.com/ecodeclub/ekit/syncx" + "github.com/ecodeclub/mq-api" + "github.com/ecodeclub/mq-api/mqerr" +) + +type MQ struct { + locker sync.RWMutex + closed bool + topics syncx.Map[string, *Topic] +} + +func NewMQ() mq.MQ { + return &MQ{ + topics: syncx.Map[string, *Topic]{}, + } +} + +func (m *MQ) CreateTopic(ctx context.Context, topic string, partitions int) error { + if !validator.IsValidTopic(topic) { + return fmt.Errorf("%w: %s", mqerr.ErrInvalidTopic, topic) + } + if ctx.Err() != nil { + return ctx.Err() + } + m.locker.Lock() + defer m.locker.Unlock() + if m.closed { + return mqerr.ErrMQIsClosed + } + _, ok := m.topics.Load(topic) + if ok { + return mqerr.ErrInvalidTopic + } + if partitions <= 0 { + return mqerr.ErrInvalidPartition + } + m.topics.Store(topic, NewTopic(topic, partitions)) + return nil +} + +func (m *MQ) Producer(topic string) (mq.Producer, error) { + m.locker.Lock() + defer m.locker.Unlock() + if m.closed { + return nil, mqerr.ErrMQIsClosed + } + t, ok := m.topics.Load(topic) + if !ok { + return nil, mqerr.ErrUnknownTopic + } + p := &Producer{ + locker: sync.RWMutex{}, + t: t, + } + err := t.addProducer(p) + if err != nil { + return nil, err + } + return p, nil +} + +func (m *MQ) Consumer(topic, groupID string) (mq.Consumer, error) { + m.locker.Lock() + defer m.locker.Unlock() + if m.closed { + return nil, mqerr.ErrMQIsClosed + } + t, ok := m.topics.Load(topic) + if !ok { + return nil, mqerr.ErrUnknownTopic + } + group, ok := t.consumerGroups.Load(groupID) + if !ok { + group = &ConsumerGroup{ + name: groupID, + consumers: syncx.Map[string, *ConsumerMetaData]{}, + consumerPartitionBalancer: t.consumerPartitionBalancer, + partitions: t.partitions, + balanceCh: make(chan struct{}, 10), + status: StatusStable, + } + // 初始化分区消费进度 + partitionRecords := syncx.Map[int, PartitionRecord]{} + for idx, _ := range t.partitions { + partitionRecords.Store(idx, PartitionRecord{ + Index: idx, + Cursor: 0, + }) + } + group.partitionRecords = partitionRecords + } + consumer := group.JoinGroup() + t.consumerGroups.Store(groupID, group) + return consumer, nil +} + +func (m *MQ) Close() error { + m.locker.Lock() + defer m.locker.Unlock() + m.closed = true + m.topics.Range(func(key string, value *Topic) bool { + err := value.Close() + if err != nil { + log.Printf("topic: %s关闭失败 %v", key, err) + } + return true + }) + + return nil +} + +func (m *MQ) DeleteTopics(ctx context.Context, topics ...string) error { + m.locker.Lock() + defer m.locker.Unlock() + if m.closed { + return mqerr.ErrMQIsClosed + } + if ctx.Err() != nil { + return ctx.Err() + } + for _, t := range topics { + topic, ok := m.topics.Load(t) + if ok { + topic.Close() + m.topics.Delete(t) + } + + } + return nil +} diff --git a/memory/partition.go b/memory/partition.go new file mode 100644 index 0000000..348f54d --- /dev/null +++ b/memory/partition.go @@ -0,0 +1,40 @@ +package memory + +import ( + "sync" + + "github.com/ecodeclub/ekit/list" + "github.com/ecodeclub/mq-api" +) + +// Partition 表示分区 是并发安全的 +const ( + defaultPartitionCap = 64 +) + +type Partition struct { + locker sync.RWMutex + data *list.ArrayList[*mq.Message] +} + +func NewPartition() *Partition { + return &Partition{ + data: list.NewArrayList[*mq.Message](defaultPartitionCap), + } +} + +func (p *Partition) sendMsg(msg *mq.Message) { + p.locker.Lock() + defer p.locker.Unlock() + msg.Offset = int64(p.data.Len()) + _ = p.data.Append(msg) +} + +func (p *Partition) consumerMsg(cursor, limit int) []*mq.Message { + p.locker.RLock() + defer p.locker.RUnlock() + wantLen := cursor + limit + 1 + length := min(wantLen, p.data.Len()) + res := p.data.AsSlice()[cursor:length] + return res +} diff --git a/memory/producer.go b/memory/producer.go new file mode 100644 index 0000000..235144b --- /dev/null +++ b/memory/producer.go @@ -0,0 +1,51 @@ +package memory + +import ( + "context" + "github.com/ecodeclub/mq-api/mqerr" + "sync" + + "github.com/ecodeclub/mq-api" +) + +type Producer struct { + t *Topic + closed bool + locker sync.RWMutex +} + +func (p *Producer) Produce(ctx context.Context, m *mq.Message) (*mq.ProducerResult, error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } + // 将partition设为 -1,按系统分配算法分配到某个分区 + if p.isClosed() { + return nil, mqerr.ErrProducerIsClosed + } + err := p.t.addMessage(m) + return &mq.ProducerResult{}, err +} + +func (p *Producer) ProduceWithPartition(ctx context.Context, m *mq.Message, partition int) (*mq.ProducerResult, error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } + if p.isClosed() { + return nil, mqerr.ErrProducerIsClosed + } + err := p.t.addMessage(m, int64(partition)) + return &mq.ProducerResult{}, err +} + +func (p *Producer) Close() error { + p.locker.Lock() + defer p.locker.Unlock() + p.closed = true + return nil +} + +func (p *Producer) isClosed() bool { + p.locker.RLock() + defer p.locker.RUnlock() + return p.closed +} diff --git a/memory/produceridgetter/hash/get.go b/memory/produceridgetter/hash/get.go new file mode 100644 index 0000000..7c74645 --- /dev/null +++ b/memory/produceridgetter/hash/get.go @@ -0,0 +1,19 @@ +package hash + +import "hash/fnv" + +type Getter struct { + Partition int +} + +// GetPartitionId 暂时使用hash,保证同一个key的值,在同一个分区。 +func (g *Getter) GetPartitionId(key string) int64 { + return hashString(key, g.Partition) +} + +func hashString(s string, numBuckets int) int64 { + h := fnv.New32a() + h.Write([]byte(s)) + hash := h.Sum32() + return int64(hash % uint32(numBuckets)) +} diff --git a/memory/produceridgetter/hash/get_test.go b/memory/produceridgetter/hash/get_test.go new file mode 100644 index 0000000..7adc22f --- /dev/null +++ b/memory/produceridgetter/hash/get_test.go @@ -0,0 +1 @@ +package hash diff --git a/memory/topic.go b/memory/topic.go new file mode 100644 index 0000000..d84672b --- /dev/null +++ b/memory/topic.go @@ -0,0 +1,86 @@ +package memory + +import ( + "log" + "sync" + + "github.com/ecodeclub/ekit/syncx" + "github.com/ecodeclub/mq-api" + "github.com/ecodeclub/mq-api/memory/consumerpartitionassigner/equaldivide" + "github.com/ecodeclub/mq-api/memory/produceridgetter/hash" + "github.com/ecodeclub/mq-api/mqerr" +) + +type Topic struct { + // 用[]*mq.Message表示一个分区 + locker sync.RWMutex + closed bool + name string + partitions []*Partition + producers []mq.Producer + // 消费组 + consumerGroups syncx.Map[string, *ConsumerGroup] + // 生产消息的时候获取分区号 + partitionIDGetter PartitionIDGetter + consumerPartitionBalancer ConsumerPartitionAssigner +} +type TopicOption func(t *Topic) + +func NewTopic(name string, partitions int) *Topic { + t := &Topic{ + name: name, + consumerGroups: syncx.Map[string, *ConsumerGroup]{}, + consumerPartitionBalancer: equaldivide.NewBalancer(), + partitionIDGetter: &hash.Getter{Partition: partitions}, + } + partitionList := make([]*Partition, 0, partitions) + for i := 0; i < partitions; i++ { + partitionList = append(partitionList, NewPartition()) + } + t.partitions = partitionList + return t +} +func (t *Topic) addProducer(producer mq.Producer) error { + t.locker.Lock() + defer t.locker.Unlock() + if t.closed { + return mqerr.ErrMQIsClosed + } + t.producers = append(t.producers, producer) + return nil +} + +// addMessage 往分区里面添加消息 +func (t *Topic) addMessage(msg *mq.Message, partition ...int64) error { + var partitionID int64 + if len(partition) == 0 { + partitionID = t.partitionIDGetter.GetPartitionId(string(msg.Key)) + } else if len(partition) == 1 { + partitionID = partition[0] + } else { + return mqerr.ErrInvalidPartition + } + if partitionID < 0 || int(partitionID) >= len(t.partitions) { + return mqerr.ErrInvalidPartition + } + msg.Topic = t.name + msg.Partition = partitionID + t.partitions[partitionID].sendMsg(msg) + log.Printf("生产消息 %s,消息为 %s", t.name, msg.Value) + return nil +} + +func (t *Topic) Close() error { + t.locker.Lock() + defer t.locker.Unlock() + if !t.closed { + t.consumerGroups.Range(func(key string, value *ConsumerGroup) bool { + value.Close() + return true + }) + for _, producer := range t.producers { + _ = producer.Close() + } + } + return nil +} diff --git a/memory/type.go b/memory/type.go new file mode 100644 index 0000000..e79f867 --- /dev/null +++ b/memory/type.go @@ -0,0 +1,14 @@ +package memory + +// PartitionIDGetter 此抽象用于Producer获取对应分区号 +type PartitionIDGetter interface { + // GetPartitionId 用于Producer获取分区号,返回值就是分区号 + GetPartitionId(key string) int64 +} + +// ConsumerPartitionAssigner 此抽象是给消费组使用,用于将分区分配给消费组内的消费者。 +type ConsumerPartitionAssigner interface { + // AssignPartition consumerList为消费组内的所有消费者, + // partitions表示分区数,返回值为map[name][]int name对应consumerList的索引,对应的值消费者可消费的分区 + AssignPartition(consumers []string, partitions int) map[string][]int +} From 065f859049ae53a9455e01977b80ba745a92468b Mon Sep 17 00:00:00 2001 From: zwl <1633720889@qq.com> Date: Wed, 6 Mar 2024 23:17:41 +0800 Subject: [PATCH 02/12] =?UTF-8?q?=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- memory/consumer.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/memory/consumer.go b/memory/consumer.go index c07b289..9e3fa08 100644 --- a/memory/consumer.go +++ b/memory/consumer.go @@ -72,7 +72,6 @@ func (c *Consumer) Run() { ErrChan: errCh, }, } - log.Printf("获取是否消费成功", c.name) err := <-errCh if err != nil { log.Printf("上报偏移量失败:%v", err) @@ -83,11 +82,9 @@ func (c *Consumer) Run() { } log.Printf("消费者 %s 结束消费数据", c.name) case event, ok := <-c.receiveCh: - log.Println(ok, "xxxxxxxxxxxxooooooooooo", c.name) if !ok { return } - log.Println(event.Type, "xxxxxxxxxxxx", c.name) // 处理各种事件 c.Handle(event) } From 8e1f243ac69a202e0b6b5e7620f792fe3726d4f8 Mon Sep 17 00:00:00 2001 From: zwl <1633720889@qq.com> Date: Wed, 6 Mar 2024 23:18:17 +0800 Subject: [PATCH 03/12] make fmt --- e2e/base_test.go | 7 ++++--- memory/consumergroup.go | 10 ++++------ memory/mq.go | 5 +++-- memory/producer.go | 3 ++- memory/topic.go | 1 + 5 files changed, 14 insertions(+), 12 deletions(-) diff --git a/e2e/base_test.go b/e2e/base_test.go index 7bf69f7..574ad49 100644 --- a/e2e/base_test.go +++ b/e2e/base_test.go @@ -18,14 +18,15 @@ package e2e import ( "context" - "github.com/ecodeclub/mq-api/mqerr" - "github.com/stretchr/testify/assert" - "golang.org/x/sync/errgroup" "log" "sync" "testing" "time" + "github.com/ecodeclub/mq-api/mqerr" + "github.com/stretchr/testify/assert" + "golang.org/x/sync/errgroup" + "github.com/ecodeclub/mq-api" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" diff --git a/memory/consumergroup.go b/memory/consumergroup.go index 93af58c..11f1dab 100644 --- a/memory/consumergroup.go +++ b/memory/consumergroup.go @@ -2,13 +2,14 @@ package memory import ( "fmt" - "github.com/ecodeclub/ekit/syncx" - "github.com/ecodeclub/mq-api" - "github.com/pkg/errors" "log" "sync" "sync/atomic" "time" + + "github.com/ecodeclub/ekit/syncx" + "github.com/ecodeclub/mq-api" + "github.com/pkg/errors" ) var ErrReportOffsetFail = errors.New("非平衡状态,无法上报偏移量") @@ -119,7 +120,6 @@ func (c *ConsumerGroup) ExitGroup(name string, closeCh chan struct{}) { return } } - } // ReportOffset 上报偏移量 @@ -131,7 +131,6 @@ func (c *ConsumerGroup) ReportOffset(records []PartitionRecord) error { c.partitionRecords.Store(record.Index, record) } return nil - } func (c *ConsumerGroup) Close() { @@ -202,7 +201,6 @@ func (c *ConsumerGroup) reBalance() { } } log.Println("重平衡结束") - } // JoinGroup 加入消费组 diff --git a/memory/mq.go b/memory/mq.go index 72bf49a..720bfc6 100644 --- a/memory/mq.go +++ b/memory/mq.go @@ -3,10 +3,11 @@ package memory import ( "context" "fmt" - "github.com/ecodeclub/mq-api/internal/pkg/validator" "log" "sync" + "github.com/ecodeclub/mq-api/internal/pkg/validator" + "github.com/ecodeclub/ekit/syncx" "github.com/ecodeclub/mq-api" "github.com/ecodeclub/mq-api/mqerr" @@ -90,7 +91,7 @@ func (m *MQ) Consumer(topic, groupID string) (mq.Consumer, error) { } // 初始化分区消费进度 partitionRecords := syncx.Map[int, PartitionRecord]{} - for idx, _ := range t.partitions { + for idx := range t.partitions { partitionRecords.Store(idx, PartitionRecord{ Index: idx, Cursor: 0, diff --git a/memory/producer.go b/memory/producer.go index 235144b..ab5f42e 100644 --- a/memory/producer.go +++ b/memory/producer.go @@ -2,9 +2,10 @@ package memory import ( "context" - "github.com/ecodeclub/mq-api/mqerr" "sync" + "github.com/ecodeclub/mq-api/mqerr" + "github.com/ecodeclub/mq-api" ) diff --git a/memory/topic.go b/memory/topic.go index d84672b..2f4f4c2 100644 --- a/memory/topic.go +++ b/memory/topic.go @@ -40,6 +40,7 @@ func NewTopic(name string, partitions int) *Topic { t.partitions = partitionList return t } + func (t *Topic) addProducer(producer mq.Producer) error { t.locker.Lock() defer t.locker.Unlock() From 5b949ca62d0aada2d51e921feed191505bc98503 Mon Sep 17 00:00:00 2001 From: zwl <1633720889@qq.com> Date: Sat, 9 Mar 2024 17:47:16 +0800 Subject: [PATCH 04/12] =?UTF-8?q?=E9=9D=99=E6=80=81=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E6=A3=80=E6=9F=A5=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- e2e/memory_test.go | 14 ++ memory/consumer.go | 16 ++- memory/consumergroup.go | 132 ++++++++++-------- .../equaldivide/balancer.go | 2 - .../equaldivide/balancer_test.go | 63 +++++++++ memory/mq.go | 20 ++- memory/partition.go | 14 ++ memory/producer.go | 14 ++ memory/produceridgetter/hash/get.go | 18 ++- memory/produceridgetter/hash/get_test.go | 31 ++++ memory/topic.go | 24 +++- memory/type.go | 16 ++- 12 files changed, 293 insertions(+), 71 deletions(-) create mode 100644 memory/consumerpartitionassigner/equaldivide/balancer_test.go diff --git a/e2e/memory_test.go b/e2e/memory_test.go index 2af0296..8f51c82 100644 --- a/e2e/memory_test.go +++ b/e2e/memory_test.go @@ -1,3 +1,17 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + //go:build e2e package e2e diff --git a/memory/consumer.go b/memory/consumer.go index 9e3fa08..2966923 100644 --- a/memory/consumer.go +++ b/memory/consumer.go @@ -1,3 +1,17 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package memory import ( @@ -104,7 +118,7 @@ func (c *Consumer) Handle(event *Event) { // 设置消费进度 partitionInfo := <-c.receiveCh log.Printf("消费者 %s接收到分区信息 %v", c.name, partitionInfo) - c.partitionRecords = partitionInfo.Data.([]PartitionRecord) + c.partitionRecords, _ = partitionInfo.Data.([]PartitionRecord) // 返回设置完成的信号 c.reportCh <- &Event{ Type: PartitionNotifyAck, diff --git a/memory/consumergroup.go b/memory/consumergroup.go index 11f1dab..1ec2402 100644 --- a/memory/consumergroup.go +++ b/memory/consumergroup.go @@ -1,28 +1,42 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package memory import ( "fmt" "log" - "sync" "sync/atomic" "time" - "github.com/ecodeclub/ekit/syncx" "github.com/ecodeclub/mq-api" + + "github.com/ecodeclub/ekit/syncx" "github.com/pkg/errors" ) var ErrReportOffsetFail = errors.New("非平衡状态,无法上报偏移量") const ( + consumerCap = 16 + defaultEventCap = 16 + msgChannelLength = 1000 + // ExitGroup 退出信号 ExitGroup = "exit_group" - // ExitGroupAck 退出的确认信号 - ExitGroupAck = "exit_group_ack" // ReportOffset 上报偏移量信号 ReportOffset = "report_offset" - // ReportOffsetAck 上报偏移量确认信号 - ReportOffsetAck = "report_offset_ack" // Rejoin 通知consumer重新加入消费组 Rejoin = "rejoin" // RejoinAck 表示客户端收到重新加入消费组的指令并将offset进行上报 @@ -45,10 +59,9 @@ type ConsumerGroup struct { // 消费者平衡器 consumerPartitionBalancer ConsumerPartitionAssigner // 分区消费记录 - partitionRecords syncx.Map[int, PartitionRecord] + partitionRecords *syncx.Map[int, PartitionRecord] // 分区 partitions []*Partition - once sync.Once status int32 // 用于接受在重平衡阶段channel的返回数据 balanceCh chan struct{} @@ -84,17 +97,17 @@ type Event struct { func (c *ConsumerGroup) Handler(name string, event *Event) { switch event.Type { case ExitGroup: - closeCh := event.Data.(chan struct{}) + closeCh, _ := event.Data.(chan struct{}) c.ExitGroup(name, closeCh) case ReportOffset: - data := event.Data.(ReportData) + data, _ := event.Data.(ReportData) var err error err = c.ReportOffset(data.Records) data.ErrChan <- err log.Printf("消费者%s上报offset成功", name) case RejoinAck: // consumer响应重平衡信号返回的数据,返回的是当前所有分区的偏移量 - records := event.Data.([]PartitionRecord) + records, _ := event.Data.([]PartitionRecord) // 不管上报成不成功 _ = c.ReportOffset(records) log.Printf("消费者%s成功接受到重平衡信号,并上报offset", name) @@ -109,16 +122,17 @@ func (c *ConsumerGroup) Handler(name string, event *Event) { func (c *ConsumerGroup) ExitGroup(name string, closeCh chan struct{}) { // 把自己从消费组内摘除 for { - if atomic.CompareAndSwapInt32(&c.status, StatusStable, StatusBalancing) { - defer atomic.CompareAndSwapInt32(&c.status, StatusBalancing, StatusStable) - log.Printf("消费者 %s 准备退出消费组", name) - c.consumers.Delete(name) - c.reBalance() - log.Printf("给消费者 %s 发送退出确认信号", name) - close(closeCh) - log.Printf("消费者 %s 成功退出消费组", name) - return + if !atomic.CompareAndSwapInt32(&c.status, StatusStable, StatusBalancing) { + continue } + log.Printf("消费者 %s 准备退出消费组", name) + c.consumers.Delete(name) + c.reBalance() + log.Printf("给消费者 %s 发送退出确认信号", name) + close(closeCh) + log.Printf("消费者 %s 成功退出消费组", name) + atomic.CompareAndSwapInt32(&c.status, StatusBalancing, StatusStable) + return } } @@ -149,7 +163,7 @@ func (c *ConsumerGroup) reBalance() { log.Println("开始重平衡") // 通知每一个消费者进行偏移量的上报 length := 0 - consumers := make([]string, 0, 20) + consumers := make([]string, 0, consumerCap) log.Println("开始给每个消费者,重平衡信号") c.consumers.Range(func(key string, value *ConsumerMetaData) bool { log.Printf("开始通知消费者%s", key) @@ -198,6 +212,8 @@ func (c *ConsumerGroup) reBalance() { log.Println("重平衡结束") return } + default: + } } log.Println("重平衡结束") @@ -206,50 +222,48 @@ func (c *ConsumerGroup) reBalance() { // JoinGroup 加入消费组 func (c *ConsumerGroup) JoinGroup() *Consumer { for { - if atomic.CompareAndSwapInt32(&c.status, StatusStable, StatusBalancing) { - var length int - c.consumers.Range(func(key string, value *ConsumerMetaData) bool { - length++ - return true - }) - name := fmt.Sprintf("%s_%d", c.name, length) - reportCh := make(chan *Event, 16) - receiveCh := make(chan *Event, 16) - consumer := &Consumer{ - partitions: c.partitions, - receiveCh: receiveCh, - reportCh: reportCh, - name: name, - msgCh: make(chan *mq.Message, 1000), - partitionRecords: make([]PartitionRecord, 0), - closeCh: make(chan struct{}), - } - c.consumers.Store(name, &ConsumerMetaData{ - reportCh: reportCh, - receiveCh: receiveCh, - name: name, - }) - go c.HandleConsumerSignals(name, reportCh) - go consumer.Run() - log.Println(fmt.Sprintf("新建消费者 %s", name)) - // 重平衡分配分区 - c.reBalance() - atomic.CompareAndSwapInt32(&c.status, StatusBalancing, StatusStable) - return consumer + if !atomic.CompareAndSwapInt32(&c.status, StatusStable, StatusBalancing) { + continue + } + var length int + c.consumers.Range(func(key string, value *ConsumerMetaData) bool { + length++ + return true + }) + name := fmt.Sprintf("%s_%d", c.name, length) + reportCh := make(chan *Event, defaultEventCap) + receiveCh := make(chan *Event, defaultEventCap) + consumer := &Consumer{ + partitions: c.partitions, + receiveCh: receiveCh, + reportCh: reportCh, + name: name, + msgCh: make(chan *mq.Message, msgChannelLength), + partitionRecords: []PartitionRecord{}, + closeCh: make(chan struct{}), } + c.consumers.Store(name, &ConsumerMetaData{ + reportCh: reportCh, + receiveCh: receiveCh, + name: name, + }) + go c.HandleConsumerSignals(name, reportCh) + go consumer.Run() + log.Printf("新建消费者 %s", name) + // 重平衡分配分区 + c.reBalance() + atomic.CompareAndSwapInt32(&c.status, StatusBalancing, StatusStable) + return consumer } } // HandleConsumerSignals 处理消费者上报的事件 func (c *ConsumerGroup) HandleConsumerSignals(name string, reportCh chan *Event) { - for { - select { - case event := <-reportCh: - c.Handler(name, event) - if event.Type == ExitGroup { - close(reportCh) - return - } + for event := range reportCh { + c.Handler(name, event) + if event.Type == ExitGroup { + close(reportCh) + return } } } diff --git a/memory/consumerpartitionassigner/equaldivide/balancer.go b/memory/consumerpartitionassigner/equaldivide/balancer.go index ee956dc..d8c7364 100644 --- a/memory/consumerpartitionassigner/equaldivide/balancer.go +++ b/memory/consumerpartitionassigner/equaldivide/balancer.go @@ -26,7 +26,6 @@ func (b *Balancer) AssignPartition(consumers []string, partitions int) map[strin for _, consumer := range consumers { result[consumer] = make([]int, 0) } - // 平均分配 partitions partitionIndex := 0 for i := 0; i < consumerCount; i++ { @@ -43,7 +42,6 @@ func (b *Balancer) AssignPartition(consumers []string, partitions int) map[strin partitionIndex++ } } - return result } diff --git a/memory/consumerpartitionassigner/equaldivide/balancer_test.go b/memory/consumerpartitionassigner/equaldivide/balancer_test.go new file mode 100644 index 0000000..a6f6ea6 --- /dev/null +++ b/memory/consumerpartitionassigner/equaldivide/balancer_test.go @@ -0,0 +1,63 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package equaldivide + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBalancer_AssignPartition(t *testing.T) { + t.Parallel() + balancer := NewBalancer() + testcases := []struct { + name string + consumers []string + partition int + wantAnswer map[string][]int + }{ + { + name: "分区书超过consumer个数", + consumers: []string{"c1", "c2", "c3", "c4"}, + partition: 5, + wantAnswer: map[string][]int{ + "c1": {0, 1}, + "c2": {2}, + "c3": {3}, + "c4": {4}, + }, + }, + { + name: "分区数小于consumer个数", + consumers: []string{"c1", "c2", "c3", "c4"}, + partition: 3, + wantAnswer: map[string][]int{ + "c1": {0}, + "c2": {1}, + "c3": {2}, + "c4": {}, + }, + }, + } + for _, tc := range testcases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + actualVal := balancer.AssignPartition(tc.consumers, tc.partition) + assert.Equal(t, tc.wantAnswer, actualVal) + }) + } +} diff --git a/memory/mq.go b/memory/mq.go index 720bfc6..ca93ce9 100644 --- a/memory/mq.go +++ b/memory/mq.go @@ -1,3 +1,17 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package memory import ( @@ -13,6 +27,8 @@ import ( "github.com/ecodeclub/mq-api/mqerr" ) +const defaultBalanceChLen = 10 + type MQ struct { locker sync.RWMutex closed bool @@ -86,7 +102,7 @@ func (m *MQ) Consumer(topic, groupID string) (mq.Consumer, error) { consumers: syncx.Map[string, *ConsumerMetaData]{}, consumerPartitionBalancer: t.consumerPartitionBalancer, partitions: t.partitions, - balanceCh: make(chan struct{}, 10), + balanceCh: make(chan struct{}, defaultBalanceChLen), status: StatusStable, } // 初始化分区消费进度 @@ -97,7 +113,7 @@ func (m *MQ) Consumer(topic, groupID string) (mq.Consumer, error) { Cursor: 0, }) } - group.partitionRecords = partitionRecords + group.partitionRecords = &partitionRecords } consumer := group.JoinGroup() t.consumerGroups.Store(groupID, group) diff --git a/memory/partition.go b/memory/partition.go index 348f54d..9342dc1 100644 --- a/memory/partition.go +++ b/memory/partition.go @@ -1,3 +1,17 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package memory import ( diff --git a/memory/producer.go b/memory/producer.go index ab5f42e..2186b21 100644 --- a/memory/producer.go +++ b/memory/producer.go @@ -1,3 +1,17 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package memory import ( diff --git a/memory/produceridgetter/hash/get.go b/memory/produceridgetter/hash/get.go index 7c74645..a6836e3 100644 --- a/memory/produceridgetter/hash/get.go +++ b/memory/produceridgetter/hash/get.go @@ -1,3 +1,17 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package hash import "hash/fnv" @@ -6,8 +20,8 @@ type Getter struct { Partition int } -// GetPartitionId 暂时使用hash,保证同一个key的值,在同一个分区。 -func (g *Getter) GetPartitionId(key string) int64 { +// GetPartitionID 暂时使用hash,保证同一个key的值,在同一个分区。 +func (g *Getter) GetPartitionID(key string) int64 { return hashString(key, g.Partition) } diff --git a/memory/produceridgetter/hash/get_test.go b/memory/produceridgetter/hash/get_test.go index 7adc22f..6ea0383 100644 --- a/memory/produceridgetter/hash/get_test.go +++ b/memory/produceridgetter/hash/get_test.go @@ -1 +1,32 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package hash + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGetter(t *testing.T) { + t.Parallel() + // 测试两个相同的key返回的partition是同一个 + getter := Getter{ + 3, + } + partition1 := getter.GetPartitionID("msg1") + partition2 := getter.GetPartitionID("msg2") + assert.Equal(t, partition1, partition2) +} diff --git a/memory/topic.go b/memory/topic.go index 2f4f4c2..6a3ebb8 100644 --- a/memory/topic.go +++ b/memory/topic.go @@ -1,3 +1,17 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package memory import ( @@ -54,11 +68,13 @@ func (t *Topic) addProducer(producer mq.Producer) error { // addMessage 往分区里面添加消息 func (t *Topic) addMessage(msg *mq.Message, partition ...int64) error { var partitionID int64 - if len(partition) == 0 { - partitionID = t.partitionIDGetter.GetPartitionId(string(msg.Key)) - } else if len(partition) == 1 { + partitionLen := len(partition) + switch partitionLen { + case 0: + partitionID = t.partitionIDGetter.GetPartitionID(string(msg.Key)) + case 1: partitionID = partition[0] - } else { + default: return mqerr.ErrInvalidPartition } if partitionID < 0 || int(partitionID) >= len(t.partitions) { diff --git a/memory/type.go b/memory/type.go index e79f867..2e6ce7f 100644 --- a/memory/type.go +++ b/memory/type.go @@ -1,9 +1,23 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package memory // PartitionIDGetter 此抽象用于Producer获取对应分区号 type PartitionIDGetter interface { // GetPartitionId 用于Producer获取分区号,返回值就是分区号 - GetPartitionId(key string) int64 + GetPartitionID(key string) int64 } // ConsumerPartitionAssigner 此抽象是给消费组使用,用于将分区分配给消费组内的消费者。 From 2e9035470a5bf1465af41219a5d657b2a9d992e4 Mon Sep 17 00:00:00 2001 From: zwl <1633720889@qq.com> Date: Tue, 23 Apr 2024 22:36:46 +0800 Subject: [PATCH 05/12] =?UTF-8?q?=E6=B7=BB=E5=8A=A0partition=E6=B5=8B?= =?UTF-8?q?=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- memory/consumer.go | 2 +- memory/consumergroup.go | 2 + .../equaldivide/balancer_test.go | 12 +++++- memory/partition.go | 6 +-- memory/partition_test.go | 42 +++++++++++++++++++ memory/produceridgetter/hash/get.go | 8 ++-- memory/produceridgetter/hash/get_test.go | 4 +- memory/topic.go | 6 +-- memory/type.go | 7 ++-- 9 files changed, 71 insertions(+), 18 deletions(-) create mode 100644 memory/partition_test.go diff --git a/memory/consumer.go b/memory/consumer.go index 2966923..b52ed51 100644 --- a/memory/consumer.go +++ b/memory/consumer.go @@ -72,7 +72,7 @@ func (c *Consumer) Run() { case <-ticker.C: log.Printf("消费者 %s 开始消费数据", c.name) for idx, record := range c.partitionRecords { - msgs := c.partitions[record.Index].consumerMsg(record.Cursor, limit) + msgs := c.partitions[record.Index].getBatch(record.Cursor, limit) for _, msg := range msgs { log.Printf("消费者 %s 消费数据 %v", c.name, msg) c.msgCh <- msg diff --git a/memory/consumergroup.go b/memory/consumergroup.go index 1ec2402..3fa810c 100644 --- a/memory/consumergroup.go +++ b/memory/consumergroup.go @@ -123,6 +123,7 @@ func (c *ConsumerGroup) ExitGroup(name string, closeCh chan struct{}) { // 把自己从消费组内摘除 for { if !atomic.CompareAndSwapInt32(&c.status, StatusStable, StatusBalancing) { + time.Sleep(100 * time.Millisecond) continue } log.Printf("消费者 %s 准备退出消费组", name) @@ -223,6 +224,7 @@ func (c *ConsumerGroup) reBalance() { func (c *ConsumerGroup) JoinGroup() *Consumer { for { if !atomic.CompareAndSwapInt32(&c.status, StatusStable, StatusBalancing) { + time.Sleep(100 * time.Millisecond) continue } var length int diff --git a/memory/consumerpartitionassigner/equaldivide/balancer_test.go b/memory/consumerpartitionassigner/equaldivide/balancer_test.go index a6f6ea6..7dc188d 100644 --- a/memory/consumerpartitionassigner/equaldivide/balancer_test.go +++ b/memory/consumerpartitionassigner/equaldivide/balancer_test.go @@ -30,7 +30,7 @@ func TestBalancer_AssignPartition(t *testing.T) { wantAnswer map[string][]int }{ { - name: "分区书超过consumer个数", + name: "分区数超过consumer个数", consumers: []string{"c1", "c2", "c3", "c4"}, partition: 5, wantAnswer: map[string][]int{ @@ -51,6 +51,16 @@ func TestBalancer_AssignPartition(t *testing.T) { "c4": {}, }, }, + { + name: "分区数等于consumer个数", + consumers: []string{"c1", "c2", "c3"}, + partition: 3, + wantAnswer: map[string][]int{ + "c1": {0}, + "c2": {1}, + "c3": {2}, + }, + }, } for _, tc := range testcases { tc := tc diff --git a/memory/partition.go b/memory/partition.go index 9342dc1..1114aef 100644 --- a/memory/partition.go +++ b/memory/partition.go @@ -37,17 +37,17 @@ func NewPartition() *Partition { } } -func (p *Partition) sendMsg(msg *mq.Message) { +func (p *Partition) append(msg *mq.Message) { p.locker.Lock() defer p.locker.Unlock() msg.Offset = int64(p.data.Len()) _ = p.data.Append(msg) } -func (p *Partition) consumerMsg(cursor, limit int) []*mq.Message { +func (p *Partition) getBatch(cursor, limit int) []*mq.Message { p.locker.RLock() defer p.locker.RUnlock() - wantLen := cursor + limit + 1 + wantLen := cursor + limit length := min(wantLen, p.data.Len()) res := p.data.AsSlice()[cursor:length] return res diff --git a/memory/partition_test.go b/memory/partition_test.go new file mode 100644 index 0000000..b327a4d --- /dev/null +++ b/memory/partition_test.go @@ -0,0 +1,42 @@ +package memory + +import ( + "github.com/ecodeclub/mq-api" + "github.com/stretchr/testify/assert" + "testing" +) + +func Test_Partition(t *testing.T) { + p := NewPartition() + for i := 0; i < 5; i++ { + msg := &mq.Message{Partition: int64(i)} + p.append(msg) + } + msgs := p.getBatch(2, 2) + assert.Equal(t, []*mq.Message{ + { + Partition: 2, + Offset: 2, + }, + { + Partition: 3, + Offset: 3, + }, + }, msgs) + msgs = p.getBatch(2, 5) + assert.Equal(t, []*mq.Message{ + { + Partition: 2, + Offset: 2, + }, + { + Partition: 3, + Offset: 3, + }, + { + Partition: 4, + Offset: 4, + }, + }, msgs) + +} diff --git a/memory/produceridgetter/hash/get.go b/memory/produceridgetter/hash/get.go index a6836e3..18da3ea 100644 --- a/memory/produceridgetter/hash/get.go +++ b/memory/produceridgetter/hash/get.go @@ -17,12 +17,12 @@ package hash import "hash/fnv" type Getter struct { - Partition int + Partitions int } -// GetPartitionID 暂时使用hash,保证同一个key的值,在同一个分区。 -func (g *Getter) GetPartitionID(key string) int64 { - return hashString(key, g.Partition) +// PartitionID 暂时使用hash,保证同一个key的值,在同一个分区。 +func (g *Getter) PartitionID(key string) int64 { + return hashString(key, g.Partitions) } func hashString(s string, numBuckets int) int64 { diff --git a/memory/produceridgetter/hash/get_test.go b/memory/produceridgetter/hash/get_test.go index 6ea0383..cab704f 100644 --- a/memory/produceridgetter/hash/get_test.go +++ b/memory/produceridgetter/hash/get_test.go @@ -26,7 +26,7 @@ func TestGetter(t *testing.T) { getter := Getter{ 3, } - partition1 := getter.GetPartitionID("msg1") - partition2 := getter.GetPartitionID("msg2") + partition1 := getter.PartitionID("msg1") + partition2 := getter.PartitionID("msg1") assert.Equal(t, partition1, partition2) } diff --git a/memory/topic.go b/memory/topic.go index 6a3ebb8..a3f5884 100644 --- a/memory/topic.go +++ b/memory/topic.go @@ -45,7 +45,7 @@ func NewTopic(name string, partitions int) *Topic { name: name, consumerGroups: syncx.Map[string, *ConsumerGroup]{}, consumerPartitionBalancer: equaldivide.NewBalancer(), - partitionIDGetter: &hash.Getter{Partition: partitions}, + partitionIDGetter: &hash.Getter{Partitions: partitions}, } partitionList := make([]*Partition, 0, partitions) for i := 0; i < partitions; i++ { @@ -71,7 +71,7 @@ func (t *Topic) addMessage(msg *mq.Message, partition ...int64) error { partitionLen := len(partition) switch partitionLen { case 0: - partitionID = t.partitionIDGetter.GetPartitionID(string(msg.Key)) + partitionID = t.partitionIDGetter.PartitionID(string(msg.Key)) case 1: partitionID = partition[0] default: @@ -82,7 +82,7 @@ func (t *Topic) addMessage(msg *mq.Message, partition ...int64) error { } msg.Topic = t.name msg.Partition = partitionID - t.partitions[partitionID].sendMsg(msg) + t.partitions[partitionID].append(msg) log.Printf("生产消息 %s,消息为 %s", t.name, msg.Value) return nil } diff --git a/memory/type.go b/memory/type.go index 2e6ce7f..f45bcd7 100644 --- a/memory/type.go +++ b/memory/type.go @@ -16,13 +16,12 @@ package memory // PartitionIDGetter 此抽象用于Producer获取对应分区号 type PartitionIDGetter interface { - // GetPartitionId 用于Producer获取分区号,返回值就是分区号 - GetPartitionID(key string) int64 + // PartitionID 用于Producer获取分区号,返回值就是分区号 + PartitionID(key string) int64 } // ConsumerPartitionAssigner 此抽象是给消费组使用,用于将分区分配给消费组内的消费者。 type ConsumerPartitionAssigner interface { - // AssignPartition consumerList为消费组内的所有消费者, - // partitions表示分区数,返回值为map[name][]int name对应consumerList的索引,对应的值消费者可消费的分区 + // AssignPartition partitions表示分区数,返回值为map[消费者名称][]分区索引 AssignPartition(consumers []string, partitions int) map[string][]int } From 56e4d6ce90a8f3d11ba561b2ea2b2819f096b0f6 Mon Sep 17 00:00:00 2001 From: zwl <1633720889@qq.com> Date: Tue, 23 Apr 2024 23:10:20 +0800 Subject: [PATCH 06/12] =?UTF-8?q?=E6=B7=BB=E5=8A=A0license?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- memory/partition_test.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/memory/partition_test.go b/memory/partition_test.go index 2c6dcd5..beca471 100644 --- a/memory/partition_test.go +++ b/memory/partition_test.go @@ -1,3 +1,17 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package memory import ( From cb7cd7d098b41131ebc190b0a08e97142af3cf28 Mon Sep 17 00:00:00 2001 From: zhuwenliang Date: Mon, 29 Apr 2024 17:57:38 +0800 Subject: [PATCH 07/12] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- memory/consumer.go | 5 +- memory/consumergroup.go | 12 +-- .../equaldivide/balancer.go | 8 +- .../equaldivide/balancer_test.go | 2 +- memory/mq.go | 17 ++-- memory/mq_test.go | 23 +++++ memory/partition.go | 6 +- memory/partition_test.go | 94 ++++++++++++++++--- memory/producer.go | 17 +--- memory/topic.go | 27 +++--- 10 files changed, 146 insertions(+), 65 deletions(-) create mode 100644 memory/mq_test.go diff --git a/memory/consumer.go b/memory/consumer.go index cc0c8af..ef472db 100644 --- a/memory/consumer.go +++ b/memory/consumer.go @@ -16,7 +16,6 @@ package memory import ( "context" - "errors" "log" "sync" "time" @@ -32,8 +31,6 @@ const ( limit = 25 ) -var ErrConsumerClose = errors.New("消费者已关闭") - type Consumer struct { locker sync.RWMutex name string @@ -144,7 +141,7 @@ func (c *Consumer) Close() error { c.once.Do(func() { c.closed = true c.reportCh <- &Event{ - Type: ExitGroup, + Type: ExitGroupEvent, Data: c.closeCh, } log.Printf("消费者 %s 准备关闭", c.name) diff --git a/memory/consumergroup.go b/memory/consumergroup.go index b6d9c06..4a1ead1 100644 --- a/memory/consumergroup.go +++ b/memory/consumergroup.go @@ -34,9 +34,9 @@ const ( msgChannelLength = 1000 defaultSleepTime = 100 * time.Millisecond - // ExitGroup 退出信号 - ExitGroup = "exit_group" - // ReportOffset 上报偏移量信号 + // ExitGroupEvent 退出事件 + ExitGroupEvent = "exit_group" + // ReportOffset 上报偏移量事件 ReportOffset = "report_offset" // Rejoin 通知consumer重新加入消费组 Rejoin = "rejoin" @@ -97,7 +97,7 @@ type Event struct { func (c *ConsumerGroup) Handler(name string, event *Event) { switch event.Type { - case ExitGroup: + case ExitGroupEvent: closeCh, _ := event.Data.(chan struct{}) c.ExitGroup(name, closeCh) case ReportOffset: @@ -119,7 +119,7 @@ func (c *ConsumerGroup) Handler(name string, event *Event) { } } -// ExitGroup 退出消费组 +// ExitGroupEvent 退出消费组 func (c *ConsumerGroup) ExitGroup(name string, closeCh chan struct{}) { // 把自己从消费组内摘除 for { @@ -264,7 +264,7 @@ func (c *ConsumerGroup) JoinGroup() *Consumer { func (c *ConsumerGroup) HandleConsumerSignals(name string, reportCh chan *Event) { for event := range reportCh { c.Handler(name, event) - if event.Type == ExitGroup { + if event.Type == ExitGroupEvent { close(reportCh) return } diff --git a/memory/consumerpartitionassigner/equaldivide/balancer.go b/memory/consumerpartitionassigner/equaldivide/balancer.go index d8c7364..63bbbbd 100644 --- a/memory/consumerpartitionassigner/equaldivide/balancer.go +++ b/memory/consumerpartitionassigner/equaldivide/balancer.go @@ -14,9 +14,9 @@ package equaldivide -type Balancer struct{} +type Assigner struct{} -func (b *Balancer) AssignPartition(consumers []string, partitions int) map[string][]int { +func (b *Assigner) AssignPartition(consumers []string, partitions int) map[string][]int { result := make(map[string][]int) consumerCount := len(consumers) partitionPerConsumer := partitions / consumerCount @@ -45,6 +45,6 @@ func (b *Balancer) AssignPartition(consumers []string, partitions int) map[strin return result } -func NewBalancer() *Balancer { - return &Balancer{} +func NewAssigner() *Assigner { + return &Assigner{} } diff --git a/memory/consumerpartitionassigner/equaldivide/balancer_test.go b/memory/consumerpartitionassigner/equaldivide/balancer_test.go index 7dc188d..844b50c 100644 --- a/memory/consumerpartitionassigner/equaldivide/balancer_test.go +++ b/memory/consumerpartitionassigner/equaldivide/balancer_test.go @@ -22,7 +22,7 @@ import ( func TestBalancer_AssignPartition(t *testing.T) { t.Parallel() - balancer := NewBalancer() + balancer := NewAssigner() testcases := []struct { name string consumers []string diff --git a/memory/mq.go b/memory/mq.go index 3728aaf..134cf1a 100644 --- a/memory/mq.go +++ b/memory/mq.go @@ -63,9 +63,6 @@ func (m *MQ) CreateTopic(ctx context.Context, topic string, partitions int) erro if !ok { m.topics.Store(topic, NewTopic(topic, partitions)) } - if partitions <= 0 { - return errs.ErrInvalidPartition - } return nil } @@ -77,11 +74,11 @@ func (m *MQ) Producer(topic string) (mq.Producer, error) { } t, ok := m.topics.Load(topic) if !ok { - return nil, errs.ErrInvalidTopic + t = NewTopic(topic, defaultPartitions) + m.topics.Store(topic, t) } p := &Producer{ - locker: sync.RWMutex{}, - t: t, + t: t, } err := t.addProducer(p) if err != nil { @@ -99,13 +96,14 @@ func (m *MQ) Consumer(topic, groupID string) (mq.Consumer, error) { t, ok := m.topics.Load(topic) if !ok { t = NewTopic(topic, defaultPartitions) + m.topics.Store(topic, t) } group, ok := t.consumerGroups.Load(groupID) if !ok { group = &ConsumerGroup{ name: groupID, consumers: syncx.Map[string, *ConsumerMetaData]{}, - consumerPartitionBalancer: t.consumerPartitionBalancer, + consumerPartitionBalancer: t.consumerPartitionAssigner, partitions: t.partitions, balanceCh: make(chan struct{}, defaultBalanceChLen), status: StatusStable, @@ -152,7 +150,10 @@ func (m *MQ) DeleteTopics(ctx context.Context, topics ...string) error { for _, t := range topics { topic, ok := m.topics.Load(t) if ok { - topic.Close() + err := topic.Close() + if err != nil { + log.Printf("topic: %s关闭失败 %v", t, err) + } m.topics.Delete(t) } diff --git a/memory/mq_test.go b/memory/mq_test.go new file mode 100644 index 0000000..30026d8 --- /dev/null +++ b/memory/mq_test.go @@ -0,0 +1,23 @@ +package memory + +import ( + "github.com/ecodeclub/ekit/syncx" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "testing" +) + +func TestMQ(t *testing.T) { + t.Parallel() + // 测试调用consumer 和 producer 如果topic不存在就新建 + testmq := &MQ{ + topics: syncx.Map[string, *Topic]{}, + } + _, err := testmq.Consumer("test_topic", "group1") + require.NoError(t, err) + _, ok := testmq.topics.Load("test_topic") + assert.Equal(t, ok, true) + testmq.Producer("test_topic1") + _, ok = testmq.topics.Load("test_topic1") + assert.Equal(t, ok, true) +} diff --git a/memory/partition.go b/memory/partition.go index 1114aef..72f2000 100644 --- a/memory/partition.go +++ b/memory/partition.go @@ -44,11 +44,11 @@ func (p *Partition) append(msg *mq.Message) { _ = p.data.Append(msg) } -func (p *Partition) getBatch(cursor, limit int) []*mq.Message { +func (p *Partition) getBatch(offset, limit int) []*mq.Message { p.locker.RLock() defer p.locker.RUnlock() - wantLen := cursor + limit + wantLen := offset + limit length := min(wantLen, p.data.Len()) - res := p.data.AsSlice()[cursor:length] + res := p.data.AsSlice()[offset:length] return res } diff --git a/memory/partition_test.go b/memory/partition_test.go index beca471..4c2f405 100644 --- a/memory/partition_test.go +++ b/memory/partition_test.go @@ -15,6 +15,8 @@ package memory import ( + "strconv" + "sync" "testing" "github.com/ecodeclub/mq-api" @@ -25,33 +27,103 @@ func Test_Partition(t *testing.T) { t.Parallel() p := NewPartition() for i := 0; i < 5; i++ { - msg := &mq.Message{Partition: int64(i)} + msg := &mq.Message{Value: []byte(strconv.Itoa(i))} p.append(msg) } msgs := p.getBatch(2, 2) assert.Equal(t, []*mq.Message{ { - Partition: 2, - Offset: 2, + Value: []byte(strconv.Itoa(2)), + Offset: 2, }, { - Partition: 3, - Offset: 3, + Value: []byte(strconv.Itoa(3)), + Offset: 3, }, }, msgs) msgs = p.getBatch(2, 5) assert.Equal(t, []*mq.Message{ { - Partition: 2, - Offset: 2, + Value: []byte(strconv.Itoa(2)), + Offset: 2, }, { - Partition: 3, - Offset: 3, + Value: []byte(strconv.Itoa(3)), + Offset: 3, }, { - Partition: 4, - Offset: 4, + Value: []byte(strconv.Itoa(4)), + Offset: 4, }, }, msgs) + + // 测试多个goroutine往partition里写 + p2 := NewPartition() + wg := &sync.WaitGroup{} + for i := 0; i < 3; i++ { + wg.Add(1) + index := i * 5 + go func() { + defer wg.Done() + for j := index; j < index+5; j++ { + p2.append(&mq.Message{ + Value: []byte(strconv.Itoa(j)), + }) + } + }() + } + wg.Wait() + msgs = p2.getBatch(0, 16) + for idx := range msgs { + msgs[idx].Partition = 0 + msgs[idx].Offset = 0 + } + wantVal := []*mq.Message{ + { + Value: []byte("0"), + }, + { + Value: []byte("1"), + }, + { + Value: []byte("2"), + }, + { + Value: []byte("3"), + }, + { + Value: []byte("4"), + }, + { + Value: []byte("5"), + }, + { + Value: []byte("6"), + }, + { + Value: []byte("7"), + }, + { + Value: []byte("8"), + }, + { + Value: []byte("9"), + }, + { + Value: []byte("10"), + }, + { + Value: []byte("11"), + }, + { + Value: []byte("12"), + }, + { + Value: []byte("13"), + }, + { + Value: []byte("14"), + }, + } + assert.ElementsMatch(t, wantVal, msgs) } diff --git a/memory/producer.go b/memory/producer.go index c7361d4..2c929a5 100644 --- a/memory/producer.go +++ b/memory/producer.go @@ -16,24 +16,21 @@ package memory import ( "context" - "sync" - "github.com/ecodeclub/mq-api/internal/errs" + "sync/atomic" "github.com/ecodeclub/mq-api" ) type Producer struct { t *Topic - closed bool - locker sync.RWMutex + closed int32 } func (p *Producer) Produce(ctx context.Context, m *mq.Message) (*mq.ProducerResult, error) { if ctx.Err() != nil { return nil, ctx.Err() } - // 将partition设为 -1,按系统分配算法分配到某个分区 if p.isClosed() { return nil, errs.ErrProducerIsClosed } @@ -48,19 +45,15 @@ func (p *Producer) ProduceWithPartition(ctx context.Context, m *mq.Message, part if p.isClosed() { return nil, errs.ErrProducerIsClosed } - err := p.t.addMessage(m, int64(partition)) + err := p.t.addMessageWithPartition(m, int64(partition)) return &mq.ProducerResult{}, err } func (p *Producer) Close() error { - p.locker.Lock() - defer p.locker.Unlock() - p.closed = true + atomic.StoreInt32(&p.closed, 1) return nil } func (p *Producer) isClosed() bool { - p.locker.RLock() - defer p.locker.RUnlock() - return p.closed + return atomic.LoadInt32(&p.closed) == 1 } diff --git a/memory/topic.go b/memory/topic.go index 5f36542..74e30c4 100644 --- a/memory/topic.go +++ b/memory/topic.go @@ -26,7 +26,6 @@ import ( ) type Topic struct { - // 用[]*mq.Message表示一个分区 locker sync.RWMutex closed bool name string @@ -35,8 +34,8 @@ type Topic struct { // 消费组 consumerGroups syncx.Map[string, *ConsumerGroup] // 生产消息的时候获取分区号 - partitionIDGetter PartitionIDGetter - consumerPartitionBalancer ConsumerPartitionAssigner + producerPartitionIDGetter PartitionIDGetter + consumerPartitionAssigner ConsumerPartitionAssigner } type TopicOption func(t *Topic) @@ -44,8 +43,8 @@ func NewTopic(name string, partitions int) *Topic { t := &Topic{ name: name, consumerGroups: syncx.Map[string, *ConsumerGroup]{}, - consumerPartitionBalancer: equaldivide.NewBalancer(), - partitionIDGetter: &hash.Getter{Partitions: partitions}, + consumerPartitionAssigner: equaldivide.NewAssigner(), + producerPartitionIDGetter: &hash.Getter{Partitions: partitions}, } partitionList := make([]*Partition, 0, partitions) for i := 0; i < partitions; i++ { @@ -66,17 +65,13 @@ func (t *Topic) addProducer(producer mq.Producer) error { } // addMessage 往分区里面添加消息 -func (t *Topic) addMessage(msg *mq.Message, partition ...int64) error { - var partitionID int64 - partitionLen := len(partition) - switch partitionLen { - case 0: - partitionID = t.partitionIDGetter.PartitionID(string(msg.Key)) - case 1: - partitionID = partition[0] - default: - return errs.ErrInvalidPartition - } +// 发送消息 producer生成msg--->add +func (t *Topic) addMessage(msg *mq.Message) error { + partitionID := t.producerPartitionIDGetter.PartitionID(string(msg.Key)) + return t.addMessageWithPartition(msg, partitionID) + +} +func (t *Topic) addMessageWithPartition(msg *mq.Message, partitionID int64) error { if partitionID < 0 || int(partitionID) >= len(t.partitions) { return errs.ErrInvalidPartition } From 82136d305c846ef793dc689b1cbb4268024f1895 Mon Sep 17 00:00:00 2001 From: zwl <1633720889@qq.com> Date: Tue, 30 Apr 2024 21:30:35 +0800 Subject: [PATCH 08/12] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- memory/consumer.go | 66 +++++++++-------- memory/consumergroup.go | 153 +++++++++++++++++++--------------------- memory/mq.go | 7 +- memory/mq_test.go | 6 +- memory/producer.go | 3 +- memory/topic.go | 3 +- 6 files changed, 120 insertions(+), 118 deletions(-) diff --git a/memory/consumer.go b/memory/consumer.go index ef472db..76f803e 100644 --- a/memory/consumer.go +++ b/memory/consumer.go @@ -61,55 +61,59 @@ func (c *Consumer) Consume(ctx context.Context) (*mq.Message, error) { } // 启动Consume -func (c *Consumer) Run() { +func (c *Consumer) eventLoop() { ticker := time.NewTicker(interval) defer ticker.Stop() for { select { case <-ticker.C: log.Printf("消费者 %s 开始消费数据", c.name) - for idx, record := range c.partitionRecords { - msgs := c.partitions[record.Index].getBatch(record.Cursor, limit) - for _, msg := range msgs { - log.Printf("消费者 %s 消费数据 %v", c.name, msg) - c.msgCh <- msg - } - record.Cursor += len(msgs) - errCh := make(chan error, 1) - c.reportCh <- &Event{ - Type: ReportOffset, - Data: ReportData{ - Records: []PartitionRecord{record}, - ErrChan: errCh, - }, - } - err := <-errCh - if err != nil { - log.Printf("上报偏移量失败:%v", err) - break - } - close(errCh) - c.partitionRecords[idx] = record - } + c.consumerAndReport() log.Printf("消费者 %s 结束消费数据", c.name) case event, ok := <-c.receiveCh: if !ok { return } // 处理各种事件 - c.Handle(event) + c.handle(event) + } + } +} + +func (c *Consumer) consumerAndReport() { + for idx, record := range c.partitionRecords { + msgs := c.partitions[record.Index].getBatch(record.Offset, limit) + for _, msg := range msgs { + log.Printf("消费者 %s 消费数据 %v", c.name, msg) + c.msgCh <- msg + } + record.Offset += len(msgs) + errCh := make(chan error, 1) + c.reportCh <- &Event{ + Type: ReportOffsetEvent, + Data: ReportData{ + Records: []PartitionRecord{record}, + ErrChan: errCh, + }, + } + err := <-errCh + if err != nil { + log.Printf("上报偏移量失败:%v", err) + return } + close(errCh) + c.partitionRecords[idx] = record } } -func (c *Consumer) Handle(event *Event) { +func (c *Consumer) handle(event *Event) { switch event.Type { - // 服务端发起的重平衡事件 - case Rejoin: + // 服务端发起的重新加入事件 + case RejoinEvent: // 消费者上报消费进度 log.Printf("消费者 %s开始上报消费进度", c.name) c.reportCh <- &Event{ - Type: RejoinAck, + Type: RejoinAckEvent, Data: c.partitionRecords, } // 设置消费进度 @@ -118,9 +122,9 @@ func (c *Consumer) Handle(event *Event) { c.partitionRecords, _ = partitionInfo.Data.([]PartitionRecord) // 返回设置完成的信号 c.reportCh <- &Event{ - Type: PartitionNotifyAck, + Type: PartitionNotifyAckEvent, } - case Close: + case CloseEvent: c.Close() } } diff --git a/memory/consumergroup.go b/memory/consumergroup.go index 4a1ead1..d9991a4 100644 --- a/memory/consumergroup.go +++ b/memory/consumergroup.go @@ -17,6 +17,7 @@ package memory import ( "fmt" "log" + "sync" "sync/atomic" "time" @@ -36,17 +37,17 @@ const ( // ExitGroupEvent 退出事件 ExitGroupEvent = "exit_group" - // ReportOffset 上报偏移量事件 - ReportOffset = "report_offset" - // Rejoin 通知consumer重新加入消费组 - Rejoin = "rejoin" - // RejoinAck 表示客户端收到重新加入消费组的指令并将offset进行上报 - RejoinAck = "rejoin_ack" - Close = "close" - // PartitionNotify 下发分区情况事件 - PartitionNotify = "partition_notify" - // PartitionNotifyAck 下发分区情况确认 - PartitionNotifyAck = "partition_notify_ack" + // ReportOffsetEvent 上报偏移量事件 + ReportOffsetEvent = "report_offset" + // RejoinEvent 通知consumer重新加入消费组 + RejoinEvent = "rejoin" + // RejoinAckEvent 表示客户端收到重新加入消费组的指令并将offset进行上报 + RejoinAckEvent = "rejoin_ack" + CloseEvent = "close" + // PartitionNotifyEvent 下发分区情况事件 + PartitionNotifyEvent = "partition_notify" + // PartitionNotifyAckEvent 下发分区情况确认事件 + PartitionNotifyAckEvent = "partition_notify_ack" StatusStable = 1 // 稳定状态,可以正常的进行消费数据 StatusBalancing = 2 @@ -56,38 +57,29 @@ const ( type ConsumerGroup struct { name string // 存储消费者元数据,键为消费者的名称 - consumers syncx.Map[string, *ConsumerMetaData] + consumers syncx.Map[string, *Consumer] // 消费者平衡器 - consumerPartitionBalancer ConsumerPartitionAssigner + consumerPartitionAssigner ConsumerPartitionAssigner // 分区消费记录 partitionRecords *syncx.Map[int, PartitionRecord] // 分区 partitions []*Partition status int32 - // 用于接受在重平衡阶段channel的返回数据 - balanceCh chan struct{} + balanceCh chan struct{} + once sync.Once } type PartitionRecord struct { // 属于哪个分区 Index int // 消费进度 - Cursor int + Offset int } type ReportData struct { Records []PartitionRecord ErrChan chan error } -type ConsumerMetaData struct { - // 用于消费者上报数据,如退出或加入消费组,上报消费者消费分区的偏移量 - reportCh chan *Event - // 用于消费组给消费者下发数据,如下发开始重平衡开始通知,告知消费者可以消费的channel - receiveCh chan *Event - // 消费者的名字 - name string -} - type Event struct { // 事件类型 Type string @@ -95,32 +87,32 @@ type Event struct { Data any } -func (c *ConsumerGroup) Handler(name string, event *Event) { +func (c *ConsumerGroup) eventHandler(name string, event *Event) { switch event.Type { case ExitGroupEvent: closeCh, _ := event.Data.(chan struct{}) - c.ExitGroup(name, closeCh) - case ReportOffset: + c.exitGroup(name, closeCh) + case ReportOffsetEvent: data, _ := event.Data.(ReportData) var err error - err = c.ReportOffset(data.Records) + err = c.reportOffset(data.Records) data.ErrChan <- err log.Printf("消费者%s上报offset成功", name) - case RejoinAck: + case RejoinAckEvent: // consumer响应重平衡信号返回的数据,返回的是当前所有分区的偏移量 records, _ := event.Data.([]PartitionRecord) // 不管上报成不成功 - _ = c.ReportOffset(records) + _ = c.reportOffset(records) log.Printf("消费者%s成功接受到重平衡信号,并上报offset", name) c.balanceCh <- struct{}{} - case PartitionNotifyAck: + case PartitionNotifyAckEvent: log.Printf("消费者%s 成功设置分区信息", name) c.balanceCh <- struct{}{} } } // ExitGroupEvent 退出消费组 -func (c *ConsumerGroup) ExitGroup(name string, closeCh chan struct{}) { +func (c *ConsumerGroup) exitGroup(name string, closeCh chan struct{}) { // 把自己从消费组内摘除 for { if !atomic.CompareAndSwapInt32(&c.status, StatusStable, StatusBalancing) { @@ -138,8 +130,8 @@ func (c *ConsumerGroup) ExitGroup(name string, closeCh chan struct{}) { } } -// ReportOffset 上报偏移量 -func (c *ConsumerGroup) ReportOffset(records []PartitionRecord) error { +// ReportOffsetEvent 上报偏移量 +func (c *ConsumerGroup) reportOffset(records []PartitionRecord) error { if atomic.LoadInt32(&c.status) != StatusStable { return ErrReportOffsetFail } @@ -150,14 +142,16 @@ func (c *ConsumerGroup) ReportOffset(records []PartitionRecord) error { } func (c *ConsumerGroup) Close() { - c.consumers.Range(func(key string, value *ConsumerMetaData) bool { - value.receiveCh <- &Event{ - Type: Close, - } - return true + c.once.Do(func() { + c.consumers.Range(func(key string, value *Consumer) bool { + value.receiveCh <- &Event{ + Type: CloseEvent, + } + return true + }) + // 等待一秒退出完成 + time.Sleep(1 * time.Second) }) - // 等待一秒退出完成 - time.Sleep(1 * time.Second) } // reBalance 单独使用该方法是并发不安全的 @@ -167,10 +161,10 @@ func (c *ConsumerGroup) reBalance() { length := 0 consumers := make([]string, 0, consumerCap) log.Println("开始给每个消费者,重平衡信号") - c.consumers.Range(func(key string, value *ConsumerMetaData) bool { + c.consumers.Range(func(key string, value *Consumer) bool { log.Printf("开始通知消费者%s", key) value.receiveCh <- &Event{ - Type: Rejoin, + Type: RejoinEvent, } consumers = append(consumers, key) length++ @@ -183,39 +177,40 @@ func (c *ConsumerGroup) reBalance() { select { case <-c.balanceCh: number++ - if number == length { - // 接收到所有信号 - log.Println("所有消费者已经接受到重平衡请求,并上报了消费进度") - consumerMap := c.consumerPartitionBalancer.AssignPartition(consumers, len(c.partitions)) - // 通知所有消费者分配 - log.Println("开始分配分区") - for consumerName, partitions := range consumerMap { - // 查找消费者所属的channel - log.Printf("消费者 %s 消费 %v 分区", consumerName, partitions) - consumer, ok := c.consumers.Load(consumerName) - if ok { - // 往每个消费者的receive_channel发送partition的信息 - records := make([]PartitionRecord, 0, len(partitions)) - for _, p := range partitions { - record, ok := c.partitionRecords.Load(p) - if ok { - records = append(records, record) - } - } - consumer.receiveCh <- &Event{ - Type: PartitionNotify, - Data: records, + if number != length { + continue + } + // 接收到所有信号 + log.Println("所有消费者已经接受到重平衡请求,并上报了消费进度") + consumerMap := c.consumerPartitionAssigner.AssignPartition(consumers, len(c.partitions)) + // 通知所有消费者分配 + log.Println("开始分配分区") + for consumerName, partitions := range consumerMap { + // 查找消费者所属的channel + log.Printf("消费者 %s 消费 %v 分区", consumerName, partitions) + consumer, ok := c.consumers.Load(consumerName) + if ok { + // 往每个消费者的receive_channel发送partition的信息 + records := make([]PartitionRecord, 0, len(partitions)) + for _, p := range partitions { + record, ok := c.partitionRecords.Load(p) + if ok { + records = append(records, record) } - // 等待消费者接收到并保存 - <-c.balanceCh - } + consumer.receiveCh <- &Event{ + Type: PartitionNotifyEvent, + Data: records, + } + // 等待消费者接收到并保存 + <-c.balanceCh + } - log.Println("重平衡结束") - return } + log.Println("重平衡结束") + return default: - + time.Sleep(defaultSleepTime) } } log.Println("重平衡结束") @@ -229,7 +224,7 @@ func (c *ConsumerGroup) JoinGroup() *Consumer { continue } var length int - c.consumers.Range(func(key string, value *ConsumerMetaData) bool { + c.consumers.Range(func(key string, value *Consumer) bool { length++ return true }) @@ -245,13 +240,13 @@ func (c *ConsumerGroup) JoinGroup() *Consumer { partitionRecords: []PartitionRecord{}, closeCh: make(chan struct{}), } - c.consumers.Store(name, &ConsumerMetaData{ + c.consumers.Store(name, &Consumer{ reportCh: reportCh, receiveCh: receiveCh, name: name, }) - go c.HandleConsumerSignals(name, reportCh) - go consumer.Run() + go c.handleConsumerEvents(name, reportCh) + go consumer.eventLoop() log.Printf("新建消费者 %s", name) // 重平衡分配分区 c.reBalance() @@ -260,10 +255,10 @@ func (c *ConsumerGroup) JoinGroup() *Consumer { } } -// HandleConsumerSignals 处理消费者上报的事件 -func (c *ConsumerGroup) HandleConsumerSignals(name string, reportCh chan *Event) { +// handleConsumerEvents 处理消费者上报的事件 +func (c *ConsumerGroup) handleConsumerEvents(name string, reportCh chan *Event) { for event := range reportCh { - c.Handler(name, event) + c.eventHandler(name, event) if event.Type == ExitGroupEvent { close(reportCh) return diff --git a/memory/mq.go b/memory/mq.go index 134cf1a..62fe571 100644 --- a/memory/mq.go +++ b/memory/mq.go @@ -102,8 +102,8 @@ func (m *MQ) Consumer(topic, groupID string) (mq.Consumer, error) { if !ok { group = &ConsumerGroup{ name: groupID, - consumers: syncx.Map[string, *ConsumerMetaData]{}, - consumerPartitionBalancer: t.consumerPartitionAssigner, + consumers: syncx.Map[string, *Consumer]{}, + consumerPartitionAssigner: t.consumerPartitionAssigner, partitions: t.partitions, balanceCh: make(chan struct{}, defaultBalanceChLen), status: StatusStable, @@ -113,7 +113,7 @@ func (m *MQ) Consumer(topic, groupID string) (mq.Consumer, error) { for idx := range t.partitions { partitionRecords.Store(idx, PartitionRecord{ Index: idx, - Cursor: 0, + Offset: 0, }) } group.partitionRecords = &partitionRecords @@ -153,6 +153,7 @@ func (m *MQ) DeleteTopics(ctx context.Context, topics ...string) error { err := topic.Close() if err != nil { log.Printf("topic: %s关闭失败 %v", t, err) + continue } m.topics.Delete(t) } diff --git a/memory/mq_test.go b/memory/mq_test.go index 30026d8..a10f2ff 100644 --- a/memory/mq_test.go +++ b/memory/mq_test.go @@ -1,10 +1,11 @@ package memory import ( + "testing" + "github.com/ecodeclub/ekit/syncx" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "testing" ) func TestMQ(t *testing.T) { @@ -17,7 +18,8 @@ func TestMQ(t *testing.T) { require.NoError(t, err) _, ok := testmq.topics.Load("test_topic") assert.Equal(t, ok, true) - testmq.Producer("test_topic1") + _, err = testmq.Producer("test_topic1") + require.NoError(t, err) _, ok = testmq.topics.Load("test_topic1") assert.Equal(t, ok, true) } diff --git a/memory/producer.go b/memory/producer.go index 2c929a5..8bdd006 100644 --- a/memory/producer.go +++ b/memory/producer.go @@ -16,9 +16,10 @@ package memory import ( "context" - "github.com/ecodeclub/mq-api/internal/errs" "sync/atomic" + "github.com/ecodeclub/mq-api/internal/errs" + "github.com/ecodeclub/mq-api" ) diff --git a/memory/topic.go b/memory/topic.go index 74e30c4..8bb76b5 100644 --- a/memory/topic.go +++ b/memory/topic.go @@ -65,12 +65,11 @@ func (t *Topic) addProducer(producer mq.Producer) error { } // addMessage 往分区里面添加消息 -// 发送消息 producer生成msg--->add func (t *Topic) addMessage(msg *mq.Message) error { partitionID := t.producerPartitionIDGetter.PartitionID(string(msg.Key)) return t.addMessageWithPartition(msg, partitionID) - } + func (t *Topic) addMessageWithPartition(msg *mq.Message, partitionID int64) error { if partitionID < 0 || int(partitionID) >= len(t.partitions) { return errs.ErrInvalidPartition From e5aaa490271d7c4701e452e25d9199d371f57bc3 Mon Sep 17 00:00:00 2001 From: zwl <1633720889@qq.com> Date: Tue, 30 Apr 2024 21:46:45 +0800 Subject: [PATCH 09/12] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=20License?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- memory/mq_test.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/memory/mq_test.go b/memory/mq_test.go index a10f2ff..85a7d42 100644 --- a/memory/mq_test.go +++ b/memory/mq_test.go @@ -1,3 +1,17 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package memory import ( From 9400e6fef4d10948090b2fae9b0157831ee4d645 Mon Sep 17 00:00:00 2001 From: zwl <1633720889@qq.com> Date: Sun, 5 May 2024 22:20:58 +0800 Subject: [PATCH 10/12] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=B6=88=E8=B4=B9?= =?UTF-8?q?=E7=BB=84=EF=BC=8C=E5=85=B3=E9=97=AD=E7=9A=84=E6=97=B6=E5=80=99?= =?UTF-8?q?=E7=9A=84=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/e2e/base_test.go | 6 +- memory/consumer.go | 13 ++- memory/consumergroup.go | 150 ++++++++++++++++++++--------------- memory/consumergroup_test.go | 84 ++++++++++++++++++++ memory/mq.go | 11 ++- memory/partition_test.go | 6 +- memory/producer.go | 29 ++++--- memory/topic.go | 4 +- memory/topic_test.go | 55 +++++++++++++ 9 files changed, 271 insertions(+), 87 deletions(-) create mode 100644 memory/consumergroup_test.go create mode 100644 memory/topic_test.go diff --git a/internal/e2e/base_test.go b/internal/e2e/base_test.go index 9d063c3..8fb8c96 100644 --- a/internal/e2e/base_test.go +++ b/internal/e2e/base_test.go @@ -260,6 +260,10 @@ func (b *TestSuite) TestMQ_Producer() { err := b.messageQueue.CreateTopic(context.Background(), unknownTopic, 1) require.NoError(t, err) require.NoError(t, b.messageQueue.DeleteTopics(context.Background(), unknownTopic)) + // 如果topic不存在会默认创建,不会报错 + notExistTopic := "notExistTopic" + _, err = b.messageQueue.Producer(notExistTopic) + require.NoError(t, err) } func (b *TestSuite) TestMQ_Consumer() { @@ -270,7 +274,7 @@ func (b *TestSuite) TestMQ_Consumer() { err := b.messageQueue.CreateTopic(context.Background(), topic, 1) require.NoError(t, err) require.NoError(t, b.messageQueue.DeleteTopics(context.Background(), topic)) - + // 如果topic不存在会默认创建,不会报错 _, err = b.messageQueue.Consumer(topic, groupID) require.NoError(t, err) } diff --git a/memory/consumer.go b/memory/consumer.go index 76f803e..3d3c2ec 100644 --- a/memory/consumer.go +++ b/memory/consumer.go @@ -68,7 +68,7 @@ func (c *Consumer) eventLoop() { select { case <-ticker.C: log.Printf("消费者 %s 开始消费数据", c.name) - c.consumerAndReport() + c.consumeAndReport() log.Printf("消费者 %s 结束消费数据", c.name) case event, ok := <-c.receiveCh: if !ok { @@ -80,7 +80,7 @@ func (c *Consumer) eventLoop() { } } -func (c *Consumer) consumerAndReport() { +func (c *Consumer) consumeAndReport() { for idx, record := range c.partitionRecords { msgs := c.partitions[record.Index].getBatch(record.Offset, limit) for _, msg := range msgs { @@ -125,7 +125,14 @@ func (c *Consumer) handle(event *Event) { Type: PartitionNotifyAckEvent, } case CloseEvent: - c.Close() + // 未返回错误不做处理 + _ = c.Close() + ch, ok := event.Data.(chan struct{}) + if !ok { + return + } + ch <- struct{}{} + } } diff --git a/memory/consumergroup.go b/memory/consumergroup.go index d9991a4..ca7febc 100644 --- a/memory/consumergroup.go +++ b/memory/consumergroup.go @@ -27,7 +27,10 @@ import ( "github.com/pkg/errors" ) -var ErrReportOffsetFail = errors.New("非平衡状态,无法上报偏移量") +var ( + ErrReportOffsetFail = errors.New("非平衡状态,无法上报偏移量") + ErrConsumerGroupClosed = errors.New("消费组已经关闭") +) const ( consumerCap = 16 @@ -35,22 +38,27 @@ const ( msgChannelLength = 1000 defaultSleepTime = 100 * time.Millisecond - // ExitGroupEvent 退出事件 + // ExitGroupEvent consumer=>consumer_group 表示消费者退出消费组的事件 ExitGroupEvent = "exit_group" - // ReportOffsetEvent 上报偏移量事件 + // ReportOffsetEvent consumer=>consumer_group 表示消费者向消费组上报消费进度事件 ReportOffsetEvent = "report_offset" - // RejoinEvent 通知consumer重新加入消费组 + // RejoinEvent consumer_group=>consumer 表示消费组通知消费者重新加入消费组 RejoinEvent = "rejoin" - // RejoinAckEvent 表示客户端收到重新加入消费组的指令并将offset进行上报 + // RejoinAckEvent consumer=>consumer_group 表示消费者收到重新加入消费组的指令并将offset进行上报 RejoinAckEvent = "rejoin_ack" - CloseEvent = "close" - // PartitionNotifyEvent 下发分区情况事件 + // CloseEvent consumer_group=>consumer 表示消费组关闭所有消费者,向所有消费者发出关闭事件 + CloseEvent = "close" + // PartitionNotifyEvent consumer_group=>consumer 表示消费组向消费者下发分区情况 PartitionNotifyEvent = "partition_notify" - // PartitionNotifyAckEvent 下发分区情况确认事件 + // PartitionNotifyAckEvent consumer=>consumer_group 表示消费者对消费组下发分区情况事件的确认 PartitionNotifyAckEvent = "partition_notify_ack" StatusStable = 1 // 稳定状态,可以正常的进行消费数据 StatusBalancing = 2 + // 消费组关闭 + StatusStop = 3 + // 一个消费者正在退出消费组 + StatusStopping = 4 ) // ConsumerGroup 表示消费组是并发安全的 @@ -115,7 +123,8 @@ func (c *ConsumerGroup) eventHandler(name string, event *Event) { func (c *ConsumerGroup) exitGroup(name string, closeCh chan struct{}) { // 把自己从消费组内摘除 for { - if !atomic.CompareAndSwapInt32(&c.status, StatusStable, StatusBalancing) { + if !atomic.CompareAndSwapInt32(&c.status, StatusStable, StatusBalancing) && + !atomic.CompareAndSwapInt32(&c.status, StatusStop, StatusStopping) { time.Sleep(defaultSleepTime) continue } @@ -125,14 +134,17 @@ func (c *ConsumerGroup) exitGroup(name string, closeCh chan struct{}) { log.Printf("给消费者 %s 发送退出确认信号", name) close(closeCh) log.Printf("消费者 %s 成功退出消费组", name) - atomic.CompareAndSwapInt32(&c.status, StatusBalancing, StatusStable) + if !atomic.CompareAndSwapInt32(&c.status, StatusBalancing, StatusStable) { + atomic.CompareAndSwapInt32(&c.status, StatusStopping, StatusStop) + } return } } // ReportOffsetEvent 上报偏移量 func (c *ConsumerGroup) reportOffset(records []PartitionRecord) error { - if atomic.LoadInt32(&c.status) != StatusStable { + status := atomic.LoadInt32(&c.status) + if status != StatusStable && status != StatusStop { return ErrReportOffsetFail } for _, record := range records { @@ -143,14 +155,28 @@ func (c *ConsumerGroup) reportOffset(records []PartitionRecord) error { func (c *ConsumerGroup) Close() { c.once.Do(func() { - c.consumers.Range(func(key string, value *Consumer) bool { - value.receiveCh <- &Event{ - Type: CloseEvent, + for { + log.Println("开始关闭", c.status) + if !atomic.CompareAndSwapInt32(&c.status, StatusStable, StatusStop) { + time.Sleep(defaultSleepTime) + continue } - return true - }) - // 等待一秒退出完成 - time.Sleep(1 * time.Second) + log.Println("正在关闭", c.status) + c.close() + return + } + }) +} + +func (c *ConsumerGroup) close() { + c.consumers.Range(func(key string, value *Consumer) bool { + ch := make(chan struct{}) + value.receiveCh <- &Event{ + Type: CloseEvent, + Data: ch, + } + <-ch + return true }) } @@ -172,57 +198,59 @@ func (c *ConsumerGroup) reBalance() { return true }) number := 0 + log.Println("xxxxxxxxxx长度", length) // 等待所有消费者都接收到信号,并上报自己offset for length > 0 { - select { - case <-c.balanceCh: - number++ - if number != length { - continue - } - // 接收到所有信号 - log.Println("所有消费者已经接受到重平衡请求,并上报了消费进度") - consumerMap := c.consumerPartitionAssigner.AssignPartition(consumers, len(c.partitions)) - // 通知所有消费者分配 - log.Println("开始分配分区") - for consumerName, partitions := range consumerMap { - // 查找消费者所属的channel - log.Printf("消费者 %s 消费 %v 分区", consumerName, partitions) - consumer, ok := c.consumers.Load(consumerName) - if ok { - // 往每个消费者的receive_channel发送partition的信息 - records := make([]PartitionRecord, 0, len(partitions)) - for _, p := range partitions { - record, ok := c.partitionRecords.Load(p) - if ok { - records = append(records, record) - } - } - consumer.receiveCh <- &Event{ - Type: PartitionNotifyEvent, - Data: records, + <-c.balanceCh + number++ + if number != length { + log.Println("xxxxxxxxxx number", number) + continue + } + // 接收到所有信号 + log.Println("所有消费者已经接受到重平衡请求,并上报了消费进度") + consumerMap := c.consumerPartitionAssigner.AssignPartition(consumers, len(c.partitions)) + // 通知所有消费者分配 + log.Println("开始分配分区") + for consumerName, partitions := range consumerMap { + // 查找消费者所属的channel + log.Printf("消费者 %s 消费 %v 分区", consumerName, partitions) + consumer, ok := c.consumers.Load(consumerName) + if ok { + // 往每个消费者的receive_channel发送partition的信息 + records := make([]PartitionRecord, 0, len(partitions)) + for _, p := range partitions { + record, ok := c.partitionRecords.Load(p) + if ok { + records = append(records, record) } - // 等待消费者接收到并保存 - <-c.balanceCh - } + consumer.receiveCh <- &Event{ + Type: PartitionNotifyEvent, + Data: records, + } + // 等待消费者接收到并保存 + <-c.balanceCh + } - log.Println("重平衡结束") - return - default: - time.Sleep(defaultSleepTime) } + log.Println("重平衡结束") + return } - log.Println("重平衡结束") } // JoinGroup 加入消费组 -func (c *ConsumerGroup) JoinGroup() *Consumer { +func (c *ConsumerGroup) JoinGroup() (*Consumer, error) { for { + + if atomic.LoadInt32(&c.status) > StatusBalancing { + return nil, ErrConsumerGroupClosed + } if !atomic.CompareAndSwapInt32(&c.status, StatusStable, StatusBalancing) { time.Sleep(defaultSleepTime) continue } + var length int c.consumers.Range(func(key string, value *Consumer) bool { length++ @@ -240,23 +268,19 @@ func (c *ConsumerGroup) JoinGroup() *Consumer { partitionRecords: []PartitionRecord{}, closeCh: make(chan struct{}), } - c.consumers.Store(name, &Consumer{ - reportCh: reportCh, - receiveCh: receiveCh, - name: name, - }) - go c.handleConsumerEvents(name, reportCh) + c.consumers.Store(name, consumer) + go c.consumerEventsHandler(name, reportCh) go consumer.eventLoop() log.Printf("新建消费者 %s", name) // 重平衡分配分区 c.reBalance() atomic.CompareAndSwapInt32(&c.status, StatusBalancing, StatusStable) - return consumer + return consumer, nil } } -// handleConsumerEvents 处理消费者上报的事件 -func (c *ConsumerGroup) handleConsumerEvents(name string, reportCh chan *Event) { +// consumerEventsHandler 处理消费者上报的事件 +func (c *ConsumerGroup) consumerEventsHandler(name string, reportCh chan *Event) { for event := range reportCh { c.eventHandler(name, event) if event.Type == ExitGroupEvent { diff --git a/memory/consumergroup_test.go b/memory/consumergroup_test.go new file mode 100644 index 0000000..fa5e809 --- /dev/null +++ b/memory/consumergroup_test.go @@ -0,0 +1,84 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package memory + +import ( + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/ecodeclub/ekit/list" + "github.com/ecodeclub/ekit/syncx" + "github.com/ecodeclub/mq-api/memory/consumerpartitionassigner/equaldivide" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// 测试场景: 不断有 消费者加入 消费组,最后达成的效果,调用consumerGroup的close方法成功之后,consumerGroup里面没有consumer存在且所有的consumer都是关闭的状态 + +func TestConsumerGroup_Close(t *testing.T) { + t.Parallel() + cg := ConsumerGroup{ + name: "test_group", + consumers: syncx.Map[string, *Consumer]{}, + consumerPartitionAssigner: equaldivide.NewAssigner(), + partitions: []*Partition{ + NewPartition(), + NewPartition(), + NewPartition(), + }, + balanceCh: make(chan struct{}, defaultBalanceChLen), + status: StatusStable, + } + partitionRecords := syncx.Map[int, PartitionRecord]{} + for idx := range cg.partitions { + partitionRecords.Store(idx, PartitionRecord{ + Index: idx, + Offset: 0, + }) + } + cg.partitionRecords = &partitionRecords + var wg sync.WaitGroup + + consumerGroups := list.NewArrayList[*Consumer](30) + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + c, err := cg.JoinGroup() + if err != nil { + assert.Equal(t, ErrConsumerGroupClosed, err) + } + err = consumerGroups.Append(c) + require.NoError(t, err) + }() + } + time.Sleep(100 * time.Millisecond) + cg.Close() + wg.Wait() + // consumerGroup中没有消费者 + var flag atomic.Bool + cg.consumers.Range(func(key string, value *Consumer) bool { + flag.Store(true) + return true + }) + assert.False(t, flag.Load()) + // 所有加入的消费者都是关闭状态 + cg.consumers.Range(func(key string, value *Consumer) bool { + assert.True(t, value.closed) + return true + }) +} diff --git a/memory/mq.go b/memory/mq.go index 62fe571..03ec30b 100644 --- a/memory/mq.go +++ b/memory/mq.go @@ -61,7 +61,7 @@ func (m *MQ) CreateTopic(ctx context.Context, topic string, partitions int) erro } _, ok := m.topics.Load(topic) if !ok { - m.topics.Store(topic, NewTopic(topic, partitions)) + m.topics.Store(topic, newTopic(topic, partitions)) } return nil } @@ -74,7 +74,7 @@ func (m *MQ) Producer(topic string) (mq.Producer, error) { } t, ok := m.topics.Load(topic) if !ok { - t = NewTopic(topic, defaultPartitions) + t = newTopic(topic, defaultPartitions) m.topics.Store(topic, t) } p := &Producer{ @@ -95,7 +95,7 @@ func (m *MQ) Consumer(topic, groupID string) (mq.Consumer, error) { } t, ok := m.topics.Load(topic) if !ok { - t = NewTopic(topic, defaultPartitions) + t = newTopic(topic, defaultPartitions) m.topics.Store(topic, t) } group, ok := t.consumerGroups.Load(groupID) @@ -118,7 +118,10 @@ func (m *MQ) Consumer(topic, groupID string) (mq.Consumer, error) { } group.partitionRecords = &partitionRecords } - consumer := group.JoinGroup() + consumer, err := group.JoinGroup() + if err != nil { + return nil, err + } t.consumerGroups.Store(groupID, group) return consumer, nil } diff --git a/memory/partition_test.go b/memory/partition_test.go index 4c2f405..37b2e97 100644 --- a/memory/partition_test.go +++ b/memory/partition_test.go @@ -56,7 +56,11 @@ func Test_Partition(t *testing.T) { Offset: 4, }, }, msgs) +} +// 测试多个goroutine往同一个队列中发送消息 +func Test_PartitionConcurrent(t *testing.T) { + t.Parallel() // 测试多个goroutine往partition里写 p2 := NewPartition() wg := &sync.WaitGroup{} @@ -73,7 +77,7 @@ func Test_Partition(t *testing.T) { }() } wg.Wait() - msgs = p2.getBatch(0, 16) + msgs := p2.getBatch(0, 16) for idx := range msgs { msgs[idx].Partition = 0 msgs[idx].Offset = 0 diff --git a/memory/producer.go b/memory/producer.go index 8bdd006..6d88723 100644 --- a/memory/producer.go +++ b/memory/producer.go @@ -16,7 +16,7 @@ package memory import ( "context" - "sync/atomic" + "sync" "github.com/ecodeclub/mq-api/internal/errs" @@ -24,37 +24,40 @@ import ( ) type Producer struct { + mu sync.RWMutex t *Topic - closed int32 + closed bool } func (p *Producer) Produce(ctx context.Context, m *mq.Message) (*mq.ProducerResult, error) { + p.mu.Lock() + defer p.mu.Unlock() + if p.closed { + return nil, errs.ErrProducerIsClosed + } if ctx.Err() != nil { return nil, ctx.Err() } - if p.isClosed() { - return nil, errs.ErrProducerIsClosed - } err := p.t.addMessage(m) return &mq.ProducerResult{}, err } func (p *Producer) ProduceWithPartition(ctx context.Context, m *mq.Message, partition int) (*mq.ProducerResult, error) { + p.mu.Lock() + defer p.mu.Unlock() + if p.closed { + return nil, errs.ErrProducerIsClosed + } if ctx.Err() != nil { return nil, ctx.Err() } - if p.isClosed() { - return nil, errs.ErrProducerIsClosed - } err := p.t.addMessageWithPartition(m, int64(partition)) return &mq.ProducerResult{}, err } func (p *Producer) Close() error { - atomic.StoreInt32(&p.closed, 1) + p.mu.Lock() + defer p.mu.Unlock() + p.closed = true return nil } - -func (p *Producer) isClosed() bool { - return atomic.LoadInt32(&p.closed) == 1 -} diff --git a/memory/topic.go b/memory/topic.go index 8bb76b5..e967f76 100644 --- a/memory/topic.go +++ b/memory/topic.go @@ -37,9 +37,8 @@ type Topic struct { producerPartitionIDGetter PartitionIDGetter consumerPartitionAssigner ConsumerPartitionAssigner } -type TopicOption func(t *Topic) -func NewTopic(name string, partitions int) *Topic { +func newTopic(name string, partitions int) *Topic { t := &Topic{ name: name, consumerGroups: syncx.Map[string, *ConsumerGroup]{}, @@ -85,6 +84,7 @@ func (t *Topic) Close() error { t.locker.Lock() defer t.locker.Unlock() if !t.closed { + t.closed = true t.consumerGroups.Range(func(key string, value *ConsumerGroup) bool { value.Close() return true diff --git a/memory/topic_test.go b/memory/topic_test.go new file mode 100644 index 0000000..272b66e --- /dev/null +++ b/memory/topic_test.go @@ -0,0 +1,55 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package memory + +import ( + "context" + "testing" + + "github.com/ecodeclub/mq-api" + "github.com/ecodeclub/mq-api/internal/errs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTopic_Close(t *testing.T) { + t.Parallel() + topic := newTopic("test_topic", 3) + p1 := &Producer{ + t: topic, + } + p2 := &Producer{ + t: topic, + } + p3 := &Producer{ + t: topic, + } + err := topic.addProducer(p1) + require.NoError(t, err) + err = topic.addProducer(p2) + require.NoError(t, err) + err = topic.Close() + require.NoError(t, err) + require.Equal(t, true, topic.closed) + err = topic.Close() + require.NoError(t, err) + require.Equal(t, true, topic.closed) + err = topic.addProducer(p3) + assert.Equal(t, errs.ErrMQIsClosed, err) + _, err = p1.Produce(context.Background(), &mq.Message{ + Value: []byte("1"), + }) + assert.Equal(t, errs.ErrProducerIsClosed, err) +} From 9db0cb26d126b806ab4054bb08b71708e3f6a612 Mon Sep 17 00:00:00 2001 From: zwl <1633720889@qq.com> Date: Sun, 5 May 2024 22:35:44 +0800 Subject: [PATCH 11/12] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- memory/consumergroup.go | 2 -- memory/consumergroup_test.go | 23 +++++++++++------------ 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/memory/consumergroup.go b/memory/consumergroup.go index ca7febc..365fa74 100644 --- a/memory/consumergroup.go +++ b/memory/consumergroup.go @@ -156,12 +156,10 @@ func (c *ConsumerGroup) reportOffset(records []PartitionRecord) error { func (c *ConsumerGroup) Close() { c.once.Do(func() { for { - log.Println("开始关闭", c.status) if !atomic.CompareAndSwapInt32(&c.status, StatusStable, StatusStop) { time.Sleep(defaultSleepTime) continue } - log.Println("正在关闭", c.status) c.close() return } diff --git a/memory/consumergroup_test.go b/memory/consumergroup_test.go index fa5e809..f07b960 100644 --- a/memory/consumergroup_test.go +++ b/memory/consumergroup_test.go @@ -15,23 +15,20 @@ package memory import ( + "github.com/ecodeclub/ekit/syncx" + "github.com/ecodeclub/mq-api/memory/consumerpartitionassigner/equaldivide" + "github.com/stretchr/testify/assert" "sync" "sync/atomic" "testing" "time" - - "github.com/ecodeclub/ekit/list" - "github.com/ecodeclub/ekit/syncx" - "github.com/ecodeclub/mq-api/memory/consumerpartitionassigner/equaldivide" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) // 测试场景: 不断有 消费者加入 消费组,最后达成的效果,调用consumerGroup的close方法成功之后,consumerGroup里面没有consumer存在且所有的consumer都是关闭的状态 func TestConsumerGroup_Close(t *testing.T) { t.Parallel() - cg := ConsumerGroup{ + cg := &ConsumerGroup{ name: "test_group", consumers: syncx.Map[string, *Consumer]{}, consumerPartitionAssigner: equaldivide.NewAssigner(), @@ -52,18 +49,20 @@ func TestConsumerGroup_Close(t *testing.T) { } cg.partitionRecords = &partitionRecords var wg sync.WaitGroup - - consumerGroups := list.NewArrayList[*Consumer](30) - for i := 0; i < 100; i++ { + mu := &sync.RWMutex{} + consumerGroups := make([]*Consumer, 0, 100) + for i := 0; i < 3; i++ { wg.Add(1) go func() { defer wg.Done() c, err := cg.JoinGroup() if err != nil { assert.Equal(t, ErrConsumerGroupClosed, err) + return } - err = consumerGroups.Append(c) - require.NoError(t, err) + mu.Lock() + consumerGroups = append(consumerGroups, c) + mu.Unlock() }() } time.Sleep(100 * time.Millisecond) From b052a91c4cb5cf2a16be5bcd56ad6614b722992f Mon Sep 17 00:00:00 2001 From: zwl <1633720889@qq.com> Date: Sun, 5 May 2024 22:35:56 +0800 Subject: [PATCH 12/12] make fmt --- memory/consumergroup_test.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/memory/consumergroup_test.go b/memory/consumergroup_test.go index f07b960..a0730d2 100644 --- a/memory/consumergroup_test.go +++ b/memory/consumergroup_test.go @@ -15,13 +15,14 @@ package memory import ( - "github.com/ecodeclub/ekit/syncx" - "github.com/ecodeclub/mq-api/memory/consumerpartitionassigner/equaldivide" - "github.com/stretchr/testify/assert" "sync" "sync/atomic" "testing" "time" + + "github.com/ecodeclub/ekit/syncx" + "github.com/ecodeclub/mq-api/memory/consumerpartitionassigner/equaldivide" + "github.com/stretchr/testify/assert" ) // 测试场景: 不断有 消费者加入 消费组,最后达成的效果,调用consumerGroup的close方法成功之后,consumerGroup里面没有consumer存在且所有的consumer都是关闭的状态