From de66a5900cac885ae8e842aed9b19af8d22226ad Mon Sep 17 00:00:00 2001
From: Timon Borter <timon.borter@gmx.ch>
Date: Mon, 16 Dec 2024 19:07:38 +0100
Subject: [PATCH] fix(#1281): introduce separate consumers per subscription

---
 .../kafka/endpoint/KafkaConsumer.java         |  65 +++++++++--
 .../kafka/endpoint/KafkaEndpoint.java         |   2 +-
 .../kafka/endpoint/KafkaConsumerTest.java     | 105 ++++++++++++++++--
 .../kafka/endpoint/KafkaEndpointTest.java     |   7 +-
 4 files changed, 158 insertions(+), 21 deletions(-)

diff --git a/endpoints/citrus-kafka/src/main/java/org/citrusframework/kafka/endpoint/KafkaConsumer.java b/endpoints/citrus-kafka/src/main/java/org/citrusframework/kafka/endpoint/KafkaConsumer.java
index b84ddd456e..e0afb0db9a 100644
--- a/endpoints/citrus-kafka/src/main/java/org/citrusframework/kafka/endpoint/KafkaConsumer.java
+++ b/endpoints/citrus-kafka/src/main/java/org/citrusframework/kafka/endpoint/KafkaConsumer.java
@@ -22,11 +22,13 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.time.Duration;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Optional;
+import java.util.concurrent.ConcurrentLinkedQueue;
 
+import static java.util.Objects.isNull;
+import static java.util.Objects.nonNull;
 import static java.util.UUID.randomUUID;
 import static org.apache.kafka.clients.consumer.ConsumerConfig.AUTO_COMMIT_INTERVAL_MS_CONFIG;
 import static org.apache.kafka.clients.consumer.ConsumerConfig.AUTO_OFFSET_RESET_CONFIG;
@@ -44,28 +46,70 @@ public class KafkaConsumer extends AbstractSelectiveMessageConsumer {
     private static final Logger logger = LoggerFactory.getLogger(KafkaConsumer.class);
 
     private org.apache.kafka.clients.consumer.KafkaConsumer<Object, Object> consumer;
+    private final ConcurrentLinkedQueue<org.apache.kafka.clients.consumer.KafkaConsumer<Object, Object>> managedConsumers = new ConcurrentLinkedQueue<>();
 
     /**
      * Default constructor using endpoint.
      */
     public KafkaConsumer(String name, KafkaEndpointConfiguration endpointConfiguration) {
         super(name, endpointConfiguration);
-        this.consumer = createConsumer();
     }
 
+    /**
+     * Initializes and provides a new {@link org.apache.kafka.clients.consumer.KafkaConsumer} instance in a thread-safe manner.
+     * This method is the preferred way to obtain a consumer instance as it ensures proper lifecycle management and thread-safety.
+     * <p>
+     * The created consumer is automatically registered for lifecycle management and cleanup.
+     * Each call to this method creates a new consumer instance.
+     *
+     * @return a new thread-safe {@link org.apache.kafka.clients.consumer.KafkaConsumer} instance
+     */
+    public org.apache.kafka.clients.consumer.KafkaConsumer<Object, Object> createManagedConsumer() {
+        if (nonNull(consumer)) {
+            return consumer;
+        }
+
+        var managedConsumer = createKafkaConsumer();
+        managedConsumers.add(managedConsumer);
+        return managedConsumer;
+    }
+
+    /**
+     * Returns the current {@link org.apache.kafka.clients.consumer.KafkaConsumer} instance.
+     *
+     * @return the current {@link org.apache.kafka.clients.consumer.KafkaConsumer} instance
+     * @deprecated {@link org.apache.kafka.clients.consumer.KafkaConsumer} is <b>not</b> thread-safe and manual consumer management is error-prone.
+     * Use {@link #createManagedConsumer()} instead to obtain properly managed consumer instances.
+     * This method will be removed in a future release.
+     */
+    @Deprecated(forRemoval = true)
     public org.apache.kafka.clients.consumer.KafkaConsumer<Object, Object> getConsumer() {
+        if (isNull(consumer)) {
+            consumer = createKafkaConsumer();
+        }
+
         return consumer;
     }
 
+    /**
+     * Sets the {@link org.apache.kafka.clients.consumer.KafkaConsumer} instance.
+     *
+     * @param consumer the KafkaConsumer to set
+     * @deprecated {@link org.apache.kafka.clients.consumer.KafkaConsumer} is <b>not</b> thread-safe and manual consumer management is error-prone.
+     * Use {@link #createManagedConsumer()} instead to obtain properly managed consumer instances.
+     * This method will be removed in a future release.
+     */
+    @Deprecated(forRemoval = true)
     public void setConsumer(org.apache.kafka.clients.consumer.KafkaConsumer<Object, Object> consumer) {
         this.consumer = consumer;
+        this.managedConsumers.add(consumer);
     }
 
     @Override
     public Message receive(TestContext testContext, long timeout) {
         logger.debug("Receiving single message");
         return KafkaMessageSingleConsumer.builder()
-                .consumer(consumer)
+                .consumer(createManagedConsumer())
                 .endpointConfiguration(getEndpointConfiguration())
                 .build()
                 .receive(testContext, timeout);
@@ -75,7 +119,7 @@ public Message receive(TestContext testContext, long timeout) {
     public Message receive(String selector, TestContext testContext, long timeout) {
         logger.debug("Receiving selected message: {}", selector);
         return KafkaMessageFilteringConsumer.builder()
-                .consumer(consumer)
+                .consumer(createManagedConsumer())
                 .endpointConfiguration(getEndpointConfiguration())
                 .build()
                 .receive(selector, testContext, timeout);
@@ -90,19 +134,20 @@ protected KafkaEndpointConfiguration getEndpointConfiguration() {
      * Stop message listener container.
      */
     public void stop() {
-        try {
-            if (consumer.subscription() != null && !consumer.subscription().isEmpty()) {
-                consumer.unsubscribe();
+        org.apache.kafka.clients.consumer.KafkaConsumer<Object, Object> consumerToDelete;
+        while (nonNull(consumerToDelete = managedConsumers.poll())) {
+            try {
+                consumerToDelete.unsubscribe();
+            } finally {
+                consumerToDelete.close();
             }
-        } finally {
-            consumer.close(Duration.ofMillis(10 * 1000L));
         }
     }
 
     /**
      * Create new Kafka consumer with given endpoint configuration.
      */
-    private org.apache.kafka.clients.consumer.KafkaConsumer<Object, Object> createConsumer() {
+    private org.apache.kafka.clients.consumer.KafkaConsumer<Object, Object> createKafkaConsumer() {
         Map<String, Object> consumerProps = new HashMap<>();
         consumerProps.put(CLIENT_ID_CONFIG, Optional.ofNullable(getEndpointConfiguration().getClientId()).orElseGet(() -> KAFKA_PREFIX + "consumer_" + randomUUID()));
         consumerProps.put(GROUP_ID_CONFIG, getEndpointConfiguration().getConsumerGroup());
diff --git a/endpoints/citrus-kafka/src/main/java/org/citrusframework/kafka/endpoint/KafkaEndpoint.java b/endpoints/citrus-kafka/src/main/java/org/citrusframework/kafka/endpoint/KafkaEndpoint.java
index c960e744d7..56118d2b9e 100644
--- a/endpoints/citrus-kafka/src/main/java/org/citrusframework/kafka/endpoint/KafkaEndpoint.java
+++ b/endpoints/citrus-kafka/src/main/java/org/citrusframework/kafka/endpoint/KafkaEndpoint.java
@@ -190,7 +190,7 @@ public SimpleKafkaEndpointBuilder topic(String topic) {
         }
 
         public KafkaEndpoint build() {
-            return KafkaEndpoint.newKafkaEndpoint(kafkaConsumer, kafkaProducer, randomConsumerGroup, server, timeout, topic);
+            return newKafkaEndpoint(kafkaConsumer, kafkaProducer, randomConsumerGroup, server, timeout, topic);
         }
     }
 }
diff --git a/endpoints/citrus-kafka/src/test/java/org/citrusframework/kafka/endpoint/KafkaConsumerTest.java b/endpoints/citrus-kafka/src/test/java/org/citrusframework/kafka/endpoint/KafkaConsumerTest.java
index c1efe6414a..cc910caa4a 100644
--- a/endpoints/citrus-kafka/src/test/java/org/citrusframework/kafka/endpoint/KafkaConsumerTest.java
+++ b/endpoints/citrus-kafka/src/test/java/org/citrusframework/kafka/endpoint/KafkaConsumerTest.java
@@ -36,10 +36,14 @@
 import static java.util.Collections.singleton;
 import static java.util.Collections.singletonList;
 import static java.util.Collections.singletonMap;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.mockito.ArgumentMatchers.anyList;
 import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.reset;
+import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 import static org.testng.Assert.assertEquals;
 import static org.testng.Assert.assertNotNull;
@@ -51,7 +55,7 @@ public class KafkaConsumerTest extends AbstractTestNGUnitTest {
     private final org.apache.kafka.clients.consumer.KafkaConsumer<Object, Object> kafkaConsumerMock = mock(KafkaConsumer.class);
 
     @Test
-    public void testReceiveMessage() {
+    public void receiveMessage() {
         String topic = "default";
 
         KafkaEndpoint endpoint = KafkaEndpoint.builder()
@@ -59,7 +63,7 @@ public void testReceiveMessage() {
                 .topic(topic)
                 .build();
 
-        TopicPartition partition = new TopicPartition(topic, 0);
+        var partition = new TopicPartition(topic, 0);
 
         reset(kafkaConsumerMock);
 
@@ -83,7 +87,7 @@ public void testReceiveMessage() {
     }
 
     @Test
-    public void testReceiveMessage_inRandomConsumerGroup() {
+    public void receiveMessage_inRandomConsumerGroup() {
         String topic = "default";
 
         KafkaEndpoint endpoint = KafkaEndpoint.builder()
@@ -91,7 +95,7 @@ public void testReceiveMessage_inRandomConsumerGroup() {
                 .topic(topic)
                 .build();
 
-        TopicPartition partition = new TopicPartition(topic, 0);
+        var partition = new TopicPartition(topic, 0);
 
         reset(kafkaConsumerMock);
 
@@ -115,7 +119,7 @@ public void testReceiveMessage_inRandomConsumerGroup() {
     }
 
     @Test
-    public void testReceiveMessageTimeout() {
+    public void receiveMessage_runIntoTimeout() {
         String topic = "test";
 
         KafkaEndpoint endpoint = KafkaEndpoint.builder()
@@ -140,7 +144,7 @@ public void testReceiveMessageTimeout() {
     }
 
     @Test
-    public void testWithCustomTimeout() {
+    public void receiveMessage_customTimeout_runIntoTimeout() {
         String topic = "timeout";
 
         KafkaEndpoint endpoint = KafkaEndpoint.builder()
@@ -149,7 +153,7 @@ public void testWithCustomTimeout() {
                 .topic(topic)
                 .build();
 
-        TopicPartition partition = new TopicPartition(topic, 0);
+        var partition = new TopicPartition(topic, 0);
 
         reset(kafkaConsumerMock);
         when(kafkaConsumerMock.subscription()).thenReturn(singleton(topic));
@@ -165,7 +169,7 @@ public void testWithCustomTimeout() {
     }
 
     @Test
-    public void testWithMessageHeaders() {
+    public void receiveMessage_withMessageHeaders() {
         String topic = "headers";
 
         KafkaEndpoint endpoint = KafkaEndpoint.builder()
@@ -174,7 +178,7 @@ public void testWithMessageHeaders() {
                 .topic(topic)
                 .build();
 
-        TopicPartition partition = new TopicPartition(topic, 0);
+        var partition = new TopicPartition(topic, 0);
 
         reset(kafkaConsumerMock);
         when(kafkaConsumerMock.subscription()).thenReturn(singleton(topic));
@@ -193,4 +197,87 @@ public void testWithMessageHeaders() {
         assertNotNull(receivedMessage.getHeader("Operation"));
         assertEquals(receivedMessage.getHeader("Operation"), "sayHello");
     }
+
+    @Test
+    public void getConsumer_returnsSetConsumer() {
+        var kafkaConsumerMock = mock(KafkaConsumer.class);
+        KafkaEndpoint endpoint = KafkaEndpoint.builder()
+                .kafkaConsumer(kafkaConsumerMock)
+                .build();
+
+        var result = endpoint.createConsumer().getConsumer();
+        assertThat(result)
+                .isEqualTo(kafkaConsumerMock);
+    }
+
+    @Test
+    public void getConsumer_createsConsumerIfNonSet() {
+        KafkaEndpoint endpoint = KafkaEndpoint.builder()
+                .kafkaConsumer(null) // null for explicity
+                .build();
+
+        var result = endpoint.createConsumer().getConsumer();
+        assertThat(result)
+                .isNotNull();
+    }
+
+    @Test
+    public void createManagedConsumer_createsDifferentManagedConsumers() {
+        KafkaEndpoint endpoint = KafkaEndpoint.builder()
+                .build();
+
+        var managedConsumer1 = endpoint.createConsumer().createManagedConsumer();
+        assertThat(managedConsumer1)
+                .isNotNull();
+
+        var managedConsumer2 = endpoint.createConsumer().createManagedConsumer();
+
+        assertThat(managedConsumer2)
+                .isNotNull()
+                .isNotEqualTo(managedConsumer1)
+                .isNotSameAs(managedConsumer1);
+    }
+
+    @Test
+    @SuppressWarnings({"unchecked"})
+    public void createManagedConsumer_returnsConsumerIfOneIsSet() {
+        var kafkaConsumerMock = mock(KafkaConsumer.class);
+        KafkaEndpoint endpoint = KafkaEndpoint.builder()
+                .kafkaConsumer(kafkaConsumerMock)
+                .build();
+
+        var result = endpoint.createConsumer().createManagedConsumer();
+        assertThat(result)
+                .isEqualTo(kafkaConsumerMock);
+    }
+
+    @Test
+    @SuppressWarnings({"unchecked"})
+    public void stop_unsubscribesAndClosesConsumer() {
+        var kafkaConsumerMock = mock(KafkaConsumer.class);
+        KafkaEndpoint endpoint = KafkaEndpoint.builder()
+                .kafkaConsumer(kafkaConsumerMock)
+                .build();
+
+        endpoint.createConsumer().stop();
+        verify(kafkaConsumerMock).unsubscribe();
+        verify(kafkaConsumerMock).close();
+    }
+
+    @Test
+    @SuppressWarnings({"unchecked"})
+    public void stop_closesConsumerEvenAfterUnsubscriptionError() {
+        var kafkaConsumerMock = mock(KafkaConsumer.class);
+        var unsubscribeException = new RuntimeException();
+        doThrow(unsubscribeException).when(kafkaConsumerMock).unsubscribe();
+
+        KafkaEndpoint endpoint = KafkaEndpoint.builder()
+                .kafkaConsumer(kafkaConsumerMock)
+                .build();
+
+        assertThatThrownBy(() -> endpoint.createConsumer().stop())
+                .isEqualTo(unsubscribeException);
+
+        verify(kafkaConsumerMock).close();
+    }
 }
diff --git a/endpoints/citrus-kafka/src/test/java/org/citrusframework/kafka/endpoint/KafkaEndpointTest.java b/endpoints/citrus-kafka/src/test/java/org/citrusframework/kafka/endpoint/KafkaEndpointTest.java
index dfbbd304a0..0882cd5e0e 100644
--- a/endpoints/citrus-kafka/src/test/java/org/citrusframework/kafka/endpoint/KafkaEndpointTest.java
+++ b/endpoints/citrus-kafka/src/test/java/org/citrusframework/kafka/endpoint/KafkaEndpointTest.java
@@ -122,9 +122,14 @@ public void newKafkaEndpoint_isAbleToCreateRandomConsumerGroup() {
                 .startsWith(KAFKA_PREFIX)
                 .hasSize(23)
                 .containsPattern(".*[a-z]{10}$")
+                // Make sure the random group id is propagated to new consumers
                 .satisfies(
-                        // Additionally make sure that gets passed downstream
                         groupId -> assertThat(fixture.createConsumer().getConsumer())
+                                .extracting("delegate")
+                                .extracting("groupId")
+                                .asInstanceOf(OPTIONAL)
+                                .hasValue(groupId),
+                        groupId -> assertThat(fixture.createConsumer().createManagedConsumer())
                                 .extracting("delegate")
                                 .extracting("groupId")
                                 .asInstanceOf(OPTIONAL)