Skip to content

Commit

Permalink
feat(#1281): introduce separate consumers per subscription
Browse files Browse the repository at this point in the history
The `org.apache.kafka.clients.consumer.KafkaConsumer<K,V>` is **not thread-safe**.
When using selective message consumption, it is recommended to configure `useThreadSafeConsumer` respectively `thread-safe-consumer` for the Kafka endpoint.
Otherwise, you will experience errors when executing tests in parallel.
  • Loading branch information
bbortt committed Jan 9, 2025
1 parent ee96a03 commit 8744ddd
Show file tree
Hide file tree
Showing 18 changed files with 417 additions and 216 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public abstract class AbstractMessageConsumer implements Consumer {
/**
* Default constructor using receive timeout setting.
*/
public AbstractMessageConsumer(String name, EndpointConfiguration endpointConfiguration) {
protected AbstractMessageConsumer(String name, EndpointConfiguration endpointConfiguration) {
this.name = name;
this.endpointConfiguration = endpointConfiguration;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public abstract class AbstractSelectiveMessageConsumer extends AbstractMessageCo
* @param name
* @param endpointConfiguration
*/
public AbstractSelectiveMessageConsumer(String name, EndpointConfiguration endpointConfiguration) {
protected AbstractSelectiveMessageConsumer(String name, EndpointConfiguration endpointConfiguration) {
super(name, endpointConfiguration);
this.endpointConfiguration = endpointConfiguration;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ protected void parseEndpointConfiguration(BeanDefinitionBuilder endpointConfigur
setPropertyValue(endpointConfiguration, element.getAttribute("key-deserializer"), "keyDeserializer");
setPropertyValue(endpointConfiguration, element.getAttribute("value-serializer"), "valueSerializer");
setPropertyValue(endpointConfiguration, element.getAttribute("value-deserializer"), "valueDeserializer");

setPropertyValue(endpointConfiguration, element.getAttribute("thread-safe-consumer"), "useThreadSafeConsumer");
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,6 @@

package org.citrusframework.kafka.endpoint;

import org.citrusframework.context.TestContext;
import org.citrusframework.message.Message;
import org.citrusframework.messaging.AbstractSelectiveMessageConsumer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.time.Duration;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

import static java.util.UUID.randomUUID;
import static org.apache.kafka.clients.consumer.ConsumerConfig.AUTO_COMMIT_INTERVAL_MS_CONFIG;
import static org.apache.kafka.clients.consumer.ConsumerConfig.AUTO_OFFSET_RESET_CONFIG;
Expand All @@ -39,6 +28,16 @@
import static org.apache.kafka.clients.consumer.ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG;
import static org.citrusframework.kafka.message.KafkaMessageHeaders.KAFKA_PREFIX;

import java.time.Duration;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import org.citrusframework.context.TestContext;
import org.citrusframework.message.Message;
import org.citrusframework.messaging.AbstractSelectiveMessageConsumer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class KafkaConsumer extends AbstractSelectiveMessageConsumer {

private static final Logger logger = LoggerFactory.getLogger(KafkaConsumer.class);
Expand All @@ -65,20 +64,20 @@ public void setConsumer(org.apache.kafka.clients.consumer.KafkaConsumer<Object,
public Message receive(TestContext testContext, long timeout) {
logger.debug("Receiving single message");
return KafkaMessageSingleConsumer.builder()
.consumer(consumer)
.endpointConfiguration(getEndpointConfiguration())
.build()
.receive(testContext, timeout);
.consumer(consumer)
.endpointConfiguration(getEndpointConfiguration())
.build()
.receive(testContext, timeout);
}

@Override
public Message receive(String selector, TestContext testContext, long timeout) {
logger.debug("Receiving selected message: {}", selector);
return KafkaMessageFilteringConsumer.builder()
.consumer(consumer)
.endpointConfiguration(getEndpointConfiguration())
.build()
.receive(selector, testContext, timeout);
.consumer(consumer)
.endpointConfiguration(getEndpointConfiguration())
.build()
.receive(selector, testContext, timeout);
}

@Override
Expand All @@ -95,7 +94,7 @@ public void stop() {
consumer.unsubscribe();
}
} finally {
consumer.close(Duration.ofMillis(10 * 1000L));
consumer.close(Duration.ofSeconds(10));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,21 @@

package org.citrusframework.kafka.endpoint;

import jakarta.annotation.Nullable;
import org.apache.commons.lang3.RandomStringUtils;
import org.citrusframework.actions.ReceiveMessageAction;
import org.citrusframework.common.ShutdownPhase;
import org.citrusframework.endpoint.AbstractEndpoint;

import java.time.Duration;

import static java.lang.Boolean.TRUE;
import static java.util.Objects.isNull;
import static java.util.Objects.nonNull;
import static org.citrusframework.actions.ReceiveMessageAction.Builder.receive;
import static org.citrusframework.kafka.endpoint.selector.KafkaMessageByHeaderSelector.kafkaHeaderEquals;
import static org.citrusframework.kafka.message.KafkaMessageHeaders.KAFKA_PREFIX;
import static org.citrusframework.util.StringUtils.hasText;

import jakarta.annotation.Nullable;
import java.time.Duration;
import org.apache.commons.lang3.RandomStringUtils;
import org.citrusframework.actions.ReceiveMessageAction;
import org.citrusframework.common.ShutdownPhase;
import org.citrusframework.endpoint.AbstractEndpoint;

/**
* Kafka message endpoint capable of sending/receiving messages from Kafka message destination.
* Either uses a Kafka connection factory or a Spring Kafka template to connect with Kafka
Expand All @@ -46,6 +46,8 @@ public class KafkaEndpoint extends AbstractEndpoint implements ShutdownPhase {
private @Nullable KafkaProducer kafkaProducer;
private @Nullable KafkaConsumer kafkaConsumer;

private final ThreadLocal<KafkaConsumer> threadLocalKafkaConsumer;

public static SimpleKafkaEndpointBuilder builder() {
return new SimpleKafkaEndpointBuilder();
}
Expand All @@ -54,23 +56,24 @@ public static SimpleKafkaEndpointBuilder builder() {
* Default constructor initializing endpoint configuration.
*/
public KafkaEndpoint() {
super(new KafkaEndpointConfiguration());
this(new KafkaEndpointConfiguration());
}

/**
* Constructor with endpoint configuration.
*/
public KafkaEndpoint(KafkaEndpointConfiguration endpointConfiguration) {
super(endpointConfiguration);

threadLocalKafkaConsumer = ThreadLocal.withInitial(() -> new KafkaConsumer(getConsumerName(), getEndpointConfiguration()));
}

static KafkaEndpoint newKafkaEndpoint(
@Nullable org.apache.kafka.clients.consumer.KafkaConsumer<Object, Object> kafkaConsumer,
@Nullable org.apache.kafka.clients.producer.KafkaProducer<Object, Object> kafkaProducer,
@Nullable Boolean randomConsumerGroup,
@Nullable String server,
@Nullable Long timeout,
@Nullable String topic
@Nullable Boolean randomConsumerGroup,
@Nullable String server,
@Nullable Long timeout,
@Nullable String topic,
boolean useThreadSafeConsumer
) {
var kafkaEndpoint = new KafkaEndpoint();

Expand All @@ -88,6 +91,29 @@ static KafkaEndpoint newKafkaEndpoint(
kafkaEndpoint.getEndpointConfiguration().setTopic(topic);
}

kafkaEndpoint.getEndpointConfiguration().setUseThreadSafeConsumer(useThreadSafeConsumer);

return kafkaEndpoint;
}

/**
* @deprecated {@link org.apache.kafka.clients.consumer.KafkaConsumer} is <b>not</b> thread-safe
* and manual consumer management is error-prone. Use
* {@link #newKafkaEndpoint(Boolean, String, Long, String, boolean)} instead to obtain properly
* managed consumer instances. This method will be removed in a future release.
*/
@Deprecated(forRemoval = true)
static KafkaEndpoint newKafkaEndpoint(
@Nullable org.apache.kafka.clients.consumer.KafkaConsumer<Object, Object> kafkaConsumer,
@Nullable org.apache.kafka.clients.producer.KafkaProducer<Object, Object> kafkaProducer,
@Nullable Boolean randomConsumerGroup,
@Nullable String server,
@Nullable Long timeout,
@Nullable String topic,
boolean useThreadSafeConsumer
) {
var kafkaEndpoint = newKafkaEndpoint(randomConsumerGroup, server, timeout, topic, useThreadSafeConsumer);

// Make sure these come at the end, so endpoint configuration is already initialized
if (nonNull(kafkaConsumer)) {
kafkaEndpoint.createConsumer().setConsumer(kafkaConsumer);
Expand All @@ -99,19 +125,11 @@ static KafkaEndpoint newKafkaEndpoint(
return kafkaEndpoint;
}

@Nullable
KafkaProducer getKafkaProducer() {
return kafkaProducer;
}

@Nullable
KafkaConsumer getKafkaConsumer() {
return kafkaConsumer;
}

@Override
public KafkaConsumer createConsumer() {
if (kafkaConsumer == null) {
if (getEndpointConfiguration().useThreadSafeConsumer()) {
return threadLocalKafkaConsumer.get();
} else if (isNull(kafkaConsumer)) {
kafkaConsumer = new KafkaConsumer(getConsumerName(), getEndpointConfiguration());
}

Expand All @@ -134,20 +152,24 @@ public KafkaEndpointConfiguration getEndpointConfiguration() {

@Override
public void destroy() {
if (kafkaConsumer != null) {
if (getEndpointConfiguration().useThreadSafeConsumer()) {
threadLocalKafkaConsumer.get()
.stop();
threadLocalKafkaConsumer.remove();
} else if (nonNull(kafkaConsumer)) {
kafkaConsumer.stop();
}
}

public ReceiveMessageAction.ReceiveMessageActionBuilderSupport findKafkaEventHeaderEquals(Duration lookbackWindow, String key, String value) {
return receive(this)
.selector(
KafkaMessageFilter.kafkaMessageFilter()
.eventLookbackWindow(lookbackWindow)
.kafkaMessageSelector(kafkaHeaderEquals(key, value))
.build()
)
.getMessageBuilderSupport();
.selector(
KafkaMessageFilter.kafkaMessageFilter()
.eventLookbackWindow(lookbackWindow)
.kafkaMessageSelector(kafkaHeaderEquals(key, value))
.build()
)
.getMessageBuilderSupport();
}

public static class SimpleKafkaEndpointBuilder {
Expand All @@ -158,7 +180,13 @@ public static class SimpleKafkaEndpointBuilder {
private String server;
private Long timeout;
private String topic;
private boolean useThreadSafeConsumer = false;

/**
* @deprecated {@link org.apache.kafka.clients.consumer.KafkaConsumer} is <b>not</b> thread-safe and manual consumer management is error-prone.
* This method will be removed in a future release.
*/
@Deprecated(forRemoval = true)
public SimpleKafkaEndpointBuilder kafkaConsumer(org.apache.kafka.clients.consumer.KafkaConsumer<Object, Object> kafkaConsumer) {
this.kafkaConsumer = kafkaConsumer;
return this;
Expand Down Expand Up @@ -189,8 +217,13 @@ public SimpleKafkaEndpointBuilder topic(String topic) {
return this;
}

public SimpleKafkaEndpointBuilder useThreadSafeConsumer() {
this.useThreadSafeConsumer = true;
return this;
}

public KafkaEndpoint build() {
return KafkaEndpoint.newKafkaEndpoint(kafkaConsumer, kafkaProducer, randomConsumerGroup, server, timeout, topic);
return newKafkaEndpoint(kafkaConsumer, kafkaProducer, randomConsumerGroup, server, timeout, topic, useThreadSafeConsumer);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ public class KafkaEndpointConfiguration extends AbstractPollableEndpointConfigur
*/
private int partition = 0;

private boolean useThreadSafeConsumer;

public String getClientId() {
return clientId;
}
Expand Down Expand Up @@ -221,4 +223,12 @@ public int getPartition() {
public void setPartition(int partition) {
this.partition = partition;
}

public boolean useThreadSafeConsumer() {
return useThreadSafeConsumer;
}

public void setUseThreadSafeConsumer(boolean useThreadSafeConsumer) {
this.useThreadSafeConsumer = useThreadSafeConsumer;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,11 @@ public Message receive(String selector, TestContext testContext, long timeout) {
getEndpointConfiguration(),
testContext);

logger.info("Received Kafka message on topic: '{}", topic);
if (logger.isDebugEnabled()) {
logger.info("Received Kafka message on topic '{}': {}", topic, received);
} else {
logger.info("Received Kafka message on topic '{}'", topic);
}

return received;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,12 @@ public Message receive(TestContext testContext, long timeout) {

consumer.commitSync(Duration.ofMillis(getEndpointConfiguration().getTimeout()));

logger.info("Received Kafka message on topic: '{}", topic);
if (logger.isDebugEnabled()) {
logger.info("Received Kafka message on topic '{}': {}", topic, received);
} else {
logger.info("Received Kafka message on topic '{}'", topic);
}

return received;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
<xs:attribute name="value-deserializer" type="xs:string"/>
<xs:attribute name="producer-properties" type="xs:string"/>
<xs:attribute name="consumer-properties" type="xs:string"/>
<xs:attribute name="thread-safe-consumer" type="xs:boolean" default="false"/>
</xs:complexType>
</xs:element>

</xs:schema>
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ public void testKafkaEndpointParser() {
assertEquals(kafkaEndpoint.getEndpointConfiguration().getValueSerializer(), StringSerializer.class);
assertEquals(kafkaEndpoint.getEndpointConfiguration().getKeyDeserializer(), StringDeserializer.class);
assertEquals(kafkaEndpoint.getEndpointConfiguration().getValueDeserializer(), StringDeserializer.class);
assertEquals(kafkaEndpoint.getEndpointConfiguration().useThreadSafeConsumer(), false);

// 2nd message receiver
kafkaEndpoint = endpoints.get("kafkaEndpoint2");
Expand All @@ -90,6 +91,7 @@ public void testKafkaEndpointParser() {
assertEquals(kafkaEndpoint.getEndpointConfiguration().getValueSerializer(), ByteArraySerializer.class);
assertEquals(kafkaEndpoint.getEndpointConfiguration().getKeyDeserializer(), IntegerDeserializer.class);
assertEquals(kafkaEndpoint.getEndpointConfiguration().getValueDeserializer(), ByteArrayDeserializer.class);
assertEquals(kafkaEndpoint.getEndpointConfiguration().useThreadSafeConsumer(), false);

// 3rd message receiver
kafkaEndpoint = endpoints.get("kafkaEndpoint3");
Expand All @@ -102,12 +104,14 @@ public void testKafkaEndpointParser() {
assertEquals(kafkaEndpoint.getEndpointConfiguration().getConsumerProperties().get(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG), true);
assertEquals(kafkaEndpoint.getEndpointConfiguration().getProducerProperties().size(), 1);
assertEquals(kafkaEndpoint.getEndpointConfiguration().getProducerProperties().get(ProducerConfig.MAX_REQUEST_SIZE_CONFIG), 1024);
assertEquals(kafkaEndpoint.getEndpointConfiguration().useThreadSafeConsumer(), false);

// 4th message receiver
kafkaEndpoint = endpoints.get("kafkaEndpoint4");
assertThat(kafkaEndpoint.getEndpointConfiguration().getConsumerGroup())
.startsWith(KAFKA_PREFIX)
.hasSize(23)
.containsPattern(".*[a-z]{10}$");
assertEquals(kafkaEndpoint.getEndpointConfiguration().useThreadSafeConsumer(), true);
}
}
Loading

0 comments on commit 8744ddd

Please sign in to comment.