From b3c489cdada29f1fd2388b0da313492998f81621 Mon Sep 17 00:00:00 2001 From: Longyue Li Date: Mon, 30 Oct 2023 10:37:24 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E5=AE=8C=E5=96=84e2e=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E7=94=A8=E4=BE=8B=20(#13)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: longyue0521 --- e2e/base_test.go | 46 ++++++++++++++++++++++++++++++++++++++++++---- kafka/mq.go | 17 +++++++++++++++++ mqerr/error.go | 2 ++ 3 files changed, 61 insertions(+), 4 deletions(-) diff --git a/e2e/base_test.go b/e2e/base_test.go index 74c3a50..f7ff251 100644 --- a/e2e/base_test.go +++ b/e2e/base_test.go @@ -72,8 +72,7 @@ func (b *TestSuite) SetupSuite() { func (b *TestSuite) newProducersAndConsumers(t *testing.T, topic string, partitions int, p producerInfo, c consumerInfo) ([]mq.Producer, []mq.Consumer) { t.Helper() - err := b.messageQueue.CreateTopic(context.Background(), topic, partitions) - require.NoError(t, err) + _ = b.messageQueue.CreateTopic(context.Background(), topic, partitions) producers := make([]mq.Producer, 0, p.Num) for i := 0; i < p.Num; i++ { @@ -188,6 +187,19 @@ func (b *TestSuite) TestMQ_CreateTopic() { require.NoError(t, b.messageQueue.DeleteTopics(context.Background(), validTopic1, validTopic2)) }) + + t.Run("重复创建Topic", func(t *testing.T) { + t.Parallel() + + createdTopic, partitions := "createdTopic", 1 + err := b.messageQueue.CreateTopic(context.Background(), createdTopic, partitions) + require.NoError(t, err, createdTopic) + + err = b.messageQueue.CreateTopic(context.Background(), createdTopic, partitions) + require.Error(t, err) + + require.NoError(t, b.messageQueue.DeleteTopics(context.Background(), createdTopic)) + }) } func (b *TestSuite) TestMQ_DeleteTopics() { @@ -239,6 +251,32 @@ func (b *TestSuite) TestMQ_CreateTopicAndDeleteTopics() { require.NoError(t, eg.Wait()) } +func (b *TestSuite) TestMQ_Producer() { + t := b.T() + t.Parallel() + + unknownTopic := "producer_unknownTopic" + err := b.messageQueue.CreateTopic(context.Background(), unknownTopic, 1) + require.NoError(t, err) + require.NoError(t, b.messageQueue.DeleteTopics(context.Background(), unknownTopic)) + + _, err = b.messageQueue.Producer(unknownTopic) + require.ErrorIs(t, err, mqerr.ErrUnknownTopic) +} + +func (b *TestSuite) TestMQ_Consumer() { + t := b.T() + t.Parallel() + + unknownTopic, groupID := "consumer_unknownTopic", "c1" + err := b.messageQueue.CreateTopic(context.Background(), unknownTopic, 1) + require.NoError(t, err) + require.NoError(t, b.messageQueue.DeleteTopics(context.Background(), unknownTopic)) + + _, err = b.messageQueue.Consumer(unknownTopic, groupID) + require.ErrorIs(t, err, mqerr.ErrUnknownTopic) +} + func (b *TestSuite) TestMQ_Close() { t := b.T() t.Parallel() @@ -581,7 +619,7 @@ func (b *TestSuite) TestConsumer_ConsumeChan() { require.ErrorIs(t, err, context.Canceled) }) - t.Run("相通消费者组", func(t *testing.T) { + t.Run("相同消费者组", func(t *testing.T) { t.Parallel() t.Run("单分区_分区内顺序消费", func(t *testing.T) { @@ -736,7 +774,7 @@ func (b *TestSuite) TestConsumer_Consume() { require.ErrorIs(t, err, context.Canceled) }) - t.Run("相通消费者组", func(t *testing.T) { + t.Run("相同消费者组", func(t *testing.T) { t.Parallel() t.Run("单分区_分区内顺序消费", func(t *testing.T) { diff --git a/kafka/mq.go b/kafka/mq.go index f8d8c7b..3fe3575 100644 --- a/kafka/mq.go +++ b/kafka/mq.go @@ -91,6 +91,10 @@ func (m *MQ) CreateTopic(ctx context.Context, name string, partitions int) error return ctx.Err() } + if _, ok := m.topicConfigMapping[name]; ok { + return fmt.Errorf("kafka: %w", mqerr.ErrCreatedTopic) + } + m.topicConfigMapping[name] = kafkago.TopicConfig{Topic: name, NumPartitions: partitions, ReplicationFactor: m.replicationFactor} return m.controllerConn.CreateTopics(m.topicConfigMapping[name]) } @@ -108,6 +112,10 @@ func (m *MQ) DeleteTopics(ctx context.Context, topics ...string) error { return ctx.Err() } + for _, topic := range topics { + delete(m.topicConfigMapping, topic) + } + err := m.controllerConn.DeleteTopics(topics...) var val kafkago.Error if errors.As(err, &val) && val == kafkago.UnknownTopicOrPartition { @@ -124,6 +132,10 @@ func (m *MQ) Producer(topic string) (mq.Producer, error) { return nil, fmt.Errorf("kafka: %w", mqerr.ErrMQIsClosed) } + if _, ok := m.topicConfigMapping[topic]; !ok { + return nil, fmt.Errorf("kafka: %w", mqerr.ErrUnknownTopic) + } + balancer, _ := NewSpecifiedPartitionBalancer(&kafkago.Hash{}) p := NewProducer(m.address, topic, m.topicConfigMapping[topic].NumPartitions, balancer) m.producers = append(m.producers, p) @@ -138,6 +150,10 @@ func (m *MQ) Consumer(topic, groupID string) (mq.Consumer, error) { return nil, fmt.Errorf("kafka: %w", mqerr.ErrMQIsClosed) } + if _, ok := m.topicConfigMapping[topic]; !ok { + return nil, fmt.Errorf("kafka: %w", mqerr.ErrUnknownTopic) + } + c := NewConsumer(m.address, topic, groupID) m.consumers = append(m.consumers, c) @@ -157,6 +173,7 @@ func (m *MQ) Close() error { for _, c := range m.consumers { errorList = append(errorList, c.Close()) } + errorList = append(errorList, m.controllerConn.Close()) m.closeErr = multierr.Combine(errorList...) m.closed = true diff --git a/mqerr/error.go b/mqerr/error.go index b6f0f0d..39fb6d8 100644 --- a/mqerr/error.go +++ b/mqerr/error.go @@ -21,5 +21,7 @@ var ( ErrProducerIsClosed = errors.New("生产者已经关闭") ErrMQIsClosed = errors.New("mq已经关闭") ErrInvalidTopic = errors.New("topic非法") + ErrCreatedTopic = errors.New("topic已创建") ErrInvalidPartition = errors.New("partition非法") + ErrUnknownTopic = errors.New("未知topic") )