From 9af4330c6d59d103a44dfb3ee910e7bafb816915 Mon Sep 17 00:00:00 2001
From: Alex Damian
Date: Tue, 16 Oct 2018 11:57:11 -0400
Subject: [PATCH] Allocators (#118)

* Added allocator support for consumers and buffered producer

* Changed MessageList back to std::vector for consistency with the allocator API
---
 include/cppkafka/consumer.h                   |  39 ++++-
 include/cppkafka/queue.h                      |  42 ++++-
 include/cppkafka/utils/buffered_producer.h    | 158 +++++++++---------
 include/cppkafka/utils/poll_interface.h       |   4 +-
 .../cppkafka/utils/roundrobin_poll_strategy.h |  68 +++++++-
 src/consumer.cpp                              |  21 +--
 src/queue.cpp                                 |  23 +--
 src/utils/roundrobin_poll_strategy.cpp        |  42 +----
 tests/test_utils.h                            |   6 +-
 tests/test_utils_impl.h                       |   9 +-
 10 files changed, 247 insertions(+), 165 deletions(-)

diff --git a/include/cppkafka/consumer.h b/include/cppkafka/consumer.h
index 332b9422..ecb36e3f 100644
--- a/include/cppkafka/consumer.h
+++ b/include/cppkafka/consumer.h
@@ -379,10 +379,14 @@ class CPPKAFKA_API Consumer : public KafkaHandleBase {
      * This can return one or more messages
      *
      * \param max_batch_size The maximum amount of messages expected
+     * \param alloc The optionally supplied allocator for allocating messages
      *
      * \return A list of messages
      */
-    MessageList poll_batch(size_t max_batch_size);
+    template <typename Allocator>
+    std::vector<Message, Allocator> poll_batch(size_t max_batch_size,
+                                               const Allocator& alloc);
+    std::vector<Message> poll_batch(size_t max_batch_size);
 
     /**
      * \brief Polls for a batch of messages
@@ -391,10 +395,16 @@ class CPPKAFKA_API Consumer : public KafkaHandleBase {
      *
      * \param max_batch_size The maximum amount of messages expected
      * \param timeout The timeout for this operation
+     * \param alloc The optionally supplied allocator for allocating messages
      *
      * \return A list of messages
      */
-    MessageList poll_batch(size_t max_batch_size, std::chrono::milliseconds timeout);
+    template <typename Allocator>
+    std::vector<Message, Allocator> poll_batch(size_t max_batch_size,
+                                               std::chrono::milliseconds timeout,
+                                               const Allocator& alloc);
+    std::vector<Message> poll_batch(size_t max_batch_size,
+                                    std::chrono::milliseconds timeout);
 
     /**
      * \brief Get the global event queue servicing this consumer corresponding to
@@ -430,6 +440,7 @@ class CPPKAFKA_API Consumer : public KafkaHandleBase {
 private:
     static void rebalance_proxy(rd_kafka_t *handle, rd_kafka_resp_err_t error,
                                 rd_kafka_topic_partition_list_t *partitions, void *opaque);
+    static Queue get_queue(rd_kafka_queue_t* handle);
     void close();
     void commit(const Message& msg, bool async);
     void commit(const TopicPartitionList* topic_partitions, bool async);
@@ -440,6 +451,30 @@ class CPPKAFKA_API Consumer : public KafkaHandleBase {
     RebalanceErrorCallback rebalance_error_callback_;
 };
 
+// Implementations
+template <typename Allocator>
+std::vector<Message, Allocator> Consumer::poll_batch(size_t max_batch_size,
+                                                     const Allocator& alloc) {
+    return poll_batch(max_batch_size, get_timeout(), alloc);
+}
+
+template <typename Allocator>
+std::vector<Message, Allocator> Consumer::poll_batch(size_t max_batch_size,
+                                                     std::chrono::milliseconds timeout,
+                                                     const Allocator& alloc) {
+    std::vector<rd_kafka_message_t*> raw_messages(max_batch_size);
+    // Note that this will leak the queue when using rdkafka < 0.11.5 (see get_queue comment)
+    Queue queue(get_queue(rd_kafka_queue_get_consumer(get_handle())));
+    ssize_t result = rd_kafka_consume_batch_queue(queue.get_handle(), timeout.count(), raw_messages.data(),
+                                                  raw_messages.size());
+    if (result == -1) {
+        check_error(rd_kafka_last_error());
+        // on the off-chance that check_error() does not throw an error
+        return std::vector<Message, Allocator>(alloc);
+    }
+    return std::vector<Message, Allocator>(raw_messages.begin(), raw_messages.begin() + result, alloc);
+}
+
 } // cppkafka
 
 #endif // CPP_KAFKA_CONSUMER_H

diff --git a/include/cppkafka/queue.h b/include/cppkafka/queue.h
index d7bc5028..3791e0d7 100644
--- a/include/cppkafka/queue.h
+++ b/include/cppkafka/queue.h
@@ -138,9 +138,14 @@ class CPPKAFKA_API Queue {
      *
      * \param max_batch_size The max number of messages to consume if available
      *
+     * \param alloc The optionally supplied allocator for the message list
+     *
      * \return A list of messages. Could be empty if there's nothing to consume
      */
-    MessageList consume_batch(size_t max_batch_size) const;
+    template <typename Allocator>
+    std::vector<Message, Allocator> consume_batch(size_t max_batch_size,
+                                                  const Allocator& alloc) const;
+    std::vector<Message> consume_batch(size_t max_batch_size) const;
 
     /**
      * \brief Consumes a batch of messages from this queue
@@ -151,9 +156,16 @@ class CPPKAFKA_API Queue {
      *
      * \param timeout The timeout to be used on this call
      *
+     * \param alloc The optionally supplied allocator for the message list
+     *
      * \return A list of messages. Could be empty if there's nothing to consume
      */
-    MessageList consume_batch(size_t max_batch_size, std::chrono::milliseconds timeout) const;
+    template <typename Allocator>
+    std::vector<Message, Allocator> consume_batch(size_t max_batch_size,
+                                                  std::chrono::milliseconds timeout,
+                                                  const Allocator& alloc) const;
+    std::vector<Message> consume_batch(size_t max_batch_size,
+                                       std::chrono::milliseconds timeout) const;
 
     /**
      * Indicates whether this queue is valid (not null)
@@ -178,6 +190,32 @@ class CPPKAFKA_API Queue {
 
 using QueueList = std::vector<Queue>;
 
+template <typename Allocator>
+std::vector<Message, Allocator> Queue::consume_batch(size_t max_batch_size,
+                                                     const Allocator& alloc) const {
+    return consume_batch(max_batch_size, timeout_ms_, alloc);
+}
+
+template <typename Allocator>
+std::vector<Message, Allocator> Queue::consume_batch(size_t max_batch_size,
+                                                     std::chrono::milliseconds timeout,
+                                                     const Allocator& alloc) const {
+    std::vector<rd_kafka_message_t*> raw_messages(max_batch_size);
+    ssize_t result = rd_kafka_consume_batch_queue(handle_.get(),
+                                                  static_cast<int>(timeout.count()),
+                                                  raw_messages.data(),
+                                                  raw_messages.size());
+    if (result == -1) {
+        rd_kafka_resp_err_t error = rd_kafka_last_error();
+        if (error != RD_KAFKA_RESP_ERR_NO_ERROR) {
+            throw QueueException(error);
+        }
+        return std::vector<Message, Allocator>(alloc);
+    }
+    // Build message list
+    return std::vector<Message, Allocator>(raw_messages.begin(), raw_messages.begin() + result, alloc);
+}
+
 } // cppkafka
 
 #endif //CPPKAFKA_QUEUE_H

diff --git a/include/cppkafka/utils/buffered_producer.h b/include/cppkafka/utils/buffered_producer.h
index cfba8717..fc8cadc0 100644
--- a/include/cppkafka/utils/buffered_producer.h
+++ b/include/cppkafka/utils/buffered_producer.h
@@ -83,7 +83,8 @@ namespace cppkafka {
  * the messages *after* the ProduceSuccessCallback has reported a successful delivery to avoid memory
  * corruptions.
  */
-template <typename BufferType>
+template <typename BufferType,
+          typename Allocator = std::allocator<ConcreteMessageBuilder<BufferType>>>
 class CPPKAFKA_API BufferedProducer {
 public:
     enum class FlushMethod { Sync,    ///< Empty the buffer and wait for acks from the broker
@@ -92,6 +93,7 @@ class CPPKAFKA_API BufferedProducer {
      * Concrete builder
      */
     using Builder = ConcreteMessageBuilder<BufferType>;
+    using QueueType = std::deque<Builder, Allocator>;
 
     /**
      * Callback to indicate a message was delivered to the broker
@@ -115,8 +117,9 @@ class CPPKAFKA_API BufferedProducer {
      * \brief Constructs a buffered producer using the provided configuration
      *
      * \param config The configuration to be used on the actual Producer object
+     * \param alloc The optionally supplied allocator for the internal message buffer
      */
-    BufferedProducer(Configuration config);
+    BufferedProducer(Configuration config, const Allocator& alloc = Allocator());
 
     /**
      * \brief Adds a message to the producer's buffer.
@@ -390,7 +393,6 @@ class CPPKAFKA_API BufferedProducer {
 #endif
 
 private:
-    using QueueType = std::deque<Builder>;
     enum class MessagePriority { Low, High };
     enum class SenderType { Sync, Async };
 
@@ -466,28 +468,30 @@ Producer::PayloadPolicy get_default_payload_policy() {
     return Producer::PayloadPolicy::PASSTHROUGH_PAYLOAD;
 }
 
-template <typename BufferType>
-BufferedProducer<BufferType>::BufferedProducer(Configuration config)
-: producer_(prepare_configuration(std::move(config))) {
+template <typename BufferType, typename Allocator>
+BufferedProducer<BufferType, Allocator>::BufferedProducer(Configuration config,
+                                                          const Allocator& alloc)
+: producer_(prepare_configuration(std::move(config))),
+  messages_(alloc) {
     producer_.set_payload_policy(get_default_payload_policy());
 #ifdef KAFKA_TEST_INSTANCE
     test_params_ = nullptr;
 #endif
 }
 
-template <typename BufferType>
-void BufferedProducer<BufferType>::add_message(const MessageBuilder& builder) {
+template <typename BufferType, typename Allocator>
+void BufferedProducer<BufferType, Allocator>::add_message(const MessageBuilder& builder) {
     add_message(Builder(builder)); //make ConcreteBuilder
 }
 
-template <typename BufferType>
-void BufferedProducer<BufferType>::add_message(Builder builder) {
+template <typename BufferType, typename Allocator>
+void BufferedProducer<BufferType, Allocator>::add_message(Builder builder) {
     add_tracker(SenderType::Async, builder);
     do_add_message(move(builder), MessagePriority::Low, true);
 }
 
-template <typename BufferType>
-void BufferedProducer<BufferType>::produce(const MessageBuilder& builder) {
+template <typename BufferType, typename Allocator>
+void BufferedProducer<BufferType, Allocator>::produce(const MessageBuilder& builder) {
     if (has_internal_data_) {
         MessageBuilder builder_clone(builder.clone());
         add_tracker(SenderType::Async, builder_clone);
@@ -498,8 +502,8 @@ void BufferedProducer<BufferType>::produce(const MessageBuilder& builder) {
     }
 }
 
-template <typename BufferType>
-void BufferedProducer<BufferType>::sync_produce(const MessageBuilder& builder) {
+template <typename BufferType, typename Allocator>
+void BufferedProducer<BufferType, Allocator>::sync_produce(const MessageBuilder& builder) {
     if (has_internal_data_) {
         MessageBuilder builder_clone(builder.clone());
         TrackerPtr tracker = add_tracker(SenderType::Sync, builder_clone);
@@ -519,13 +523,13 @@ void BufferedProducer<BufferType>::sync_produce(const MessageBuilder& builder) {
     }
 }
 
-template <typename BufferType>
-void BufferedProducer<BufferType>::produce(const Message& message) {
+template <typename BufferType, typename Allocator>
+void BufferedProducer<BufferType, Allocator>::produce(const Message& message) {
     async_produce(MessageBuilder(message), true);
 }
 
-template <typename BufferType>
-void BufferedProducer<BufferType>::async_flush() {
+template <typename BufferType, typename Allocator>
+void BufferedProducer<BufferType, Allocator>::async_flush() {
     CounterGuard<size_t> counter_guard(flushes_in_progress_);
     QueueType flush_queue; // flush from temporary queue
     {
@@ -538,8 +542,8 @@ void BufferedProducer<BufferType>::async_flush() {
     }
 }
 
-template <typename BufferType>
-void BufferedProducer<BufferType>::flush(bool preserve_order) {
+template <typename BufferType, typename Allocator>
+void BufferedProducer<BufferType, Allocator>::flush(bool preserve_order) {
     if (preserve_order) {
         CounterGuard<size_t> counter_guard(flushes_in_progress_);
         QueueType flush_queue; // flush from temporary queue
@@ -558,8 +562,8 @@ void BufferedProducer<BufferType>::flush(bool preserve_order) {
     }
 }
 
-template <typename BufferType>
-bool BufferedProducer<BufferType>::flush(std::chrono::milliseconds timeout,
+template <typename BufferType, typename Allocator>
+bool BufferedProducer<BufferType, Allocator>::flush(std::chrono::milliseconds timeout,
                                          bool preserve_order) {
     if (preserve_order) {
         CounterGuard<size_t> counter_guard(flushes_in_progress_);
@@ -582,8 +586,8 @@ bool BufferedProducer<BufferType>::flush(std::chrono::milliseconds timeout,
     }
 }
 
-template <typename BufferType>
-void BufferedProducer<BufferType>::wait_for_acks() {
+template <typename BufferType, typename Allocator>
+void BufferedProducer<BufferType, Allocator>::wait_for_acks() {
     while (pending_acks_ > 0) {
         try {
             producer_.flush();
@@ -600,8 +604,8 @@ void BufferedProducer<BufferType>::wait_for_acks() {
     }
 }
 
-template <typename BufferType>
-bool BufferedProducer<BufferType>::wait_for_acks(std::chrono::milliseconds timeout) {
+template <typename BufferType, typename Allocator>
+bool BufferedProducer<BufferType, Allocator>::wait_for_acks(std::chrono::milliseconds timeout) {
     auto remaining = timeout;
     auto start_time = std::chrono::high_resolution_clock::now();
     while ((pending_acks_ > 0) &&
            (remaining.count() > 0)) {
@@ -625,47 +629,47 @@ bool BufferedProducer<BufferType>::wait_for_acks(std::chrono::milliseconds timeo
     return (pending_acks_ == 0);
 }
 
-template <typename BufferType>
-void BufferedProducer<BufferType>::clear() {
+template <typename BufferType, typename Allocator>
+void BufferedProducer<BufferType, Allocator>::clear() {
     std::lock_guard<std::mutex> lock(mutex_);
     QueueType tmp;
     std::swap(tmp, messages_);
 }
 
-template <typename BufferType>
-size_t BufferedProducer<BufferType>::get_buffer_size() const {
+template <typename BufferType, typename Allocator>
+size_t BufferedProducer<BufferType, Allocator>::get_buffer_size() const {
     return messages_.size();
 }
 
-template <typename BufferType>
-void BufferedProducer<BufferType>::set_max_buffer_size(ssize_t max_buffer_size) {
+template <typename BufferType, typename Allocator>
+void BufferedProducer<BufferType, Allocator>::set_max_buffer_size(ssize_t max_buffer_size) {
     if (max_buffer_size < -1) {
         throw Exception("Invalid buffer size.");
     }
     max_buffer_size_ = max_buffer_size;
 }
 
-template <typename BufferType>
-ssize_t BufferedProducer<BufferType>::get_max_buffer_size() const {
+template <typename BufferType, typename Allocator>
+ssize_t BufferedProducer<BufferType, Allocator>::get_max_buffer_size() const {
     return max_buffer_size_;
 }
 
-template <typename BufferType>
-void BufferedProducer<BufferType>::set_flush_method(FlushMethod method) {
+template <typename BufferType, typename Allocator>
+void BufferedProducer<BufferType, Allocator>::set_flush_method(FlushMethod method) {
    flush_method_ = method;
 }
 
-template <typename BufferType>
-typename BufferedProducer<BufferType>::FlushMethod
-BufferedProducer<BufferType>::get_flush_method() const {
+template <typename BufferType, typename Allocator>
+typename BufferedProducer<BufferType, Allocator>::FlushMethod
+BufferedProducer<BufferType, Allocator>::get_flush_method() const {
     return flush_method_;
 }
 
-template <typename BufferType>
+template <typename BufferType, typename Allocator>
 template <typename BuilderType>
-void BufferedProducer<BufferType>::do_add_message(BuilderType&& builder,
-                                                  MessagePriority priority,
-                                                  bool do_flush) {
+void BufferedProducer<BufferType, Allocator>::do_add_message(BuilderType&& builder,
+                                                             MessagePriority priority,
+                                                             bool do_flush) {
     {
         std::lock_guard<std::mutex> lock(mutex_);
         if (priority == MessagePriority::High) {
@@ -685,73 +689,73 @@ void BufferedProducer<BufferType>::do_add_message(BuilderType&& builder,
     }
 }
 
-template <typename BufferType>
-Producer& BufferedProducer<BufferType>::get_producer() {
+template <typename BufferType, typename Allocator>
+Producer& BufferedProducer<BufferType, Allocator>::get_producer() {
     return producer_;
 }
 
-template <typename BufferType>
-const Producer& BufferedProducer<BufferType>::get_producer() const {
+template <typename BufferType, typename Allocator>
+const Producer& BufferedProducer<BufferType, Allocator>::get_producer() const {
     return producer_;
 }
 
-template <typename BufferType>
-size_t BufferedProducer<BufferType>::get_pending_acks() const {
+template <typename BufferType, typename Allocator>
+size_t BufferedProducer<BufferType, Allocator>::get_pending_acks() const {
     return pending_acks_;
 }
 
-template <typename BufferType>
-size_t BufferedProducer<BufferType>::get_total_messages_produced() const {
+template <typename BufferType, typename Allocator>
+size_t BufferedProducer<BufferType, Allocator>::get_total_messages_produced() const {
     return total_messages_produced_;
 }
 
-template <typename BufferType>
-size_t BufferedProducer<BufferType>::get_total_messages_dropped() const {
+template <typename BufferType, typename Allocator>
+size_t BufferedProducer<BufferType, Allocator>::get_total_messages_dropped() const {
    return total_messages_dropped_;
 }
 
-template <typename BufferType>
-size_t BufferedProducer<BufferType>::get_flushes_in_progress() const {
+template <typename BufferType, typename Allocator>
+size_t BufferedProducer<BufferType, Allocator>::get_flushes_in_progress() const {
     return flushes_in_progress_;
 }
 
-template <typename BufferType>
-void BufferedProducer<BufferType>::set_max_number_retries(size_t max_number_retries) {
+template <typename BufferType, typename Allocator>
+void BufferedProducer<BufferType, Allocator>::set_max_number_retries(size_t max_number_retries) {
     if (!has_internal_data_ && (max_number_retries > 0)) {
         has_internal_data_ = true; //enable once
     }
     max_number_retries_ = max_number_retries;
 }
 
-template <typename BufferType>
-size_t BufferedProducer<BufferType>::get_max_number_retries() const {
+template <typename BufferType, typename Allocator>
+size_t BufferedProducer<BufferType, Allocator>::get_max_number_retries() const {
     return max_number_retries_;
 }
 
-template <typename BufferType>
-typename BufferedProducer<BufferType>::Builder
-BufferedProducer<BufferType>::make_builder(std::string topic) {
+template <typename BufferType, typename Allocator>
+typename BufferedProducer<BufferType, Allocator>::Builder
+BufferedProducer<BufferType, Allocator>::make_builder(std::string topic) {
     return Builder(std::move(topic));
 }
 
-template <typename BufferType>
-void BufferedProducer<BufferType>::set_produce_failure_callback(ProduceFailureCallback callback) {
+template <typename BufferType, typename Allocator>
+void BufferedProducer<BufferType, Allocator>::set_produce_failure_callback(ProduceFailureCallback callback) {
     produce_failure_callback_ = std::move(callback);
 }
 
-template <typename BufferType>
-void BufferedProducer<BufferType>::set_produce_success_callback(ProduceSuccessCallback callback) {
+template <typename BufferType, typename Allocator>
+void BufferedProducer<BufferType, Allocator>::set_produce_success_callback(ProduceSuccessCallback callback) {
     produce_success_callback_ = std::move(callback);
 }
 
-template <typename BufferType>
-void BufferedProducer<BufferType>::set_flush_failure_callback(FlushFailureCallback callback) {
+template <typename BufferType, typename Allocator>
+void BufferedProducer<BufferType, Allocator>::set_flush_failure_callback(FlushFailureCallback callback) {
     flush_failure_callback_ = std::move(callback);
 }
 
-template <typename BufferType>
+template <typename BufferType, typename Allocator>
 template <typename BuilderType>
-void BufferedProducer<BufferType>::produce_message(BuilderType&& builder) {
+void BufferedProducer<BufferType, Allocator>::produce_message(BuilderType&& builder) {
     using builder_type = typename std::decay<BuilderType>::type;
     while (true) {
         try {
@@ -774,9 +778,9 @@ void BufferedProducer<BufferType>::produce_message(BuilderType&& builder) {
     }
 }
 
-template <typename BufferType>
+template <typename BufferType, typename Allocator>
 template <typename BuilderType>
-void BufferedProducer<BufferType>::async_produce(BuilderType&& builder, bool throw_on_error) {
+void BufferedProducer<BufferType, Allocator>::async_produce(BuilderType&& builder, bool throw_on_error) {
     try {
         TestParameters* test_params = get_test_parameters();
         if (test_params && test_params->force_produce_error_) {
@@ -802,16 +806,16 @@ void BufferedProducer<BufferType>::async_produce(BuilderType&& builder, bool thr
     }
 }
 
-template <typename BufferType>
-Configuration BufferedProducer<BufferType>::prepare_configuration(Configuration config) {
+template <typename BufferType, typename Allocator>
+Configuration BufferedProducer<BufferType, Allocator>::prepare_configuration(Configuration config) {
     using std::placeholders::_2;
-    auto callback = std::bind(&BufferedProducer<BufferType>::on_delivery_report, this, _2);
+    auto callback = std::bind(&BufferedProducer<BufferType, Allocator>::on_delivery_report, this, _2);
     config.set_delivery_report_callback(std::move(callback));
     return config;
 }
 
-template <typename BufferType>
-void BufferedProducer<BufferType>::on_delivery_report(const Message& message) {
+template <typename BufferType, typename Allocator>
+void BufferedProducer<BufferType, Allocator>::on_delivery_report(const Message& message) {
     //Get tracker data
     TestParameters* test_params = get_test_parameters();
     TrackerPtr tracker = has_internal_data_ ?
diff --git a/include/cppkafka/utils/poll_interface.h b/include/cppkafka/utils/poll_interface.h
index af93e3f3..ac8042f6 100644
--- a/include/cppkafka/utils/poll_interface.h
+++ b/include/cppkafka/utils/poll_interface.h
@@ -108,7 +108,7 @@ struct PollInterface {
      * otherwise the broker will think this consumer is down and will trigger a rebalance
      * (if using dynamic subscription)
      */
-    virtual MessageList poll_batch(size_t max_batch_size) = 0;
+    virtual std::vector<Message> poll_batch(size_t max_batch_size) = 0;
 
     /**
      * \brief Polls all assigned partitions for a batch of new messages in round-robin fashion
@@ -122,7 +122,7 @@ struct PollInterface {
      *
      * \return A list of messages
      */
-    virtual MessageList poll_batch(size_t max_batch_size, std::chrono::milliseconds timeout) = 0;
+    virtual std::vector<Message> poll_batch(size_t max_batch_size, std::chrono::milliseconds timeout) = 0;
 };
 
 } //cppkafka

diff --git a/include/cppkafka/utils/roundrobin_poll_strategy.h b/include/cppkafka/utils/roundrobin_poll_strategy.h
index bb91e054..6ad89714 100644
--- a/include/cppkafka/utils/roundrobin_poll_strategy.h
+++ b/include/cppkafka/utils/roundrobin_poll_strategy.h
@@ -102,14 +102,21 @@ class RoundRobinPollStrategy : public PollStrategyBase {
     /**
      * \sa PollInterface::poll_batch
      */
-    MessageList poll_batch(size_t max_batch_size) override;
+    template <typename Allocator>
+    std::vector<Message, Allocator> poll_batch(size_t max_batch_size,
+                                               const Allocator& alloc);
+    std::vector<Message> poll_batch(size_t max_batch_size) override;
 
     /**
      * \sa PollInterface::poll_batch
      */
-    MessageList poll_batch(size_t max_batch_size,
-                           std::chrono::milliseconds timeout) override;
-
+    template <typename Allocator>
+    std::vector<Message, Allocator> poll_batch(size_t max_batch_size,
+                                               std::chrono::milliseconds timeout,
+                                               const Allocator& alloc);
+    std::vector<Message> poll_batch(size_t max_batch_size,
+                                    std::chrono::milliseconds timeout) override;
+
 protected:
     /**
      * \sa PollStrategyBase::reset_state
@@ -119,10 +126,12 @@ class RoundRobinPollStrategy : public PollStrategyBase {
     QueueData& get_next_queue();
 
 private:
+    template <typename Allocator>
     void consume_batch(Queue& queue,
-                       MessageList& messages,
+                       std::vector<Message, Allocator>& messages,
                        ssize_t& count,
-                       std::chrono::milliseconds timeout);
+                       std::chrono::milliseconds timeout,
+                       const Allocator& alloc);
 
     void restore_forwarding();
 
@@ -130,6 +139,53 @@ class RoundRobinPollStrategy : public PollStrategyBase {
     QueueMap::iterator queue_iter_;
 };
 
+// Implementations
+template <typename Allocator>
+std::vector<Message, Allocator> RoundRobinPollStrategy::poll_batch(size_t max_batch_size,
+                                                                   const Allocator& alloc) {
+    return poll_batch(max_batch_size, get_consumer().get_timeout(), alloc);
+}
+
+template <typename Allocator>
+std::vector<Message, Allocator> RoundRobinPollStrategy::poll_batch(size_t max_batch_size,
+                                                                   std::chrono::milliseconds timeout,
+                                                                   const Allocator& alloc) {
+    std::vector<Message, Allocator> messages(alloc);
+    ssize_t count = max_batch_size;
+
+    // batch from the group event queue first (non-blocking)
+    consume_batch(get_consumer_queue().queue, messages, count, std::chrono::milliseconds(0), alloc);
+    size_t num_queues = get_partition_queues().size();
+    while ((count > 0) && (num_queues--)) {
+        // batch from the next partition (non-blocking)
+        consume_batch(get_next_queue().queue, messages, count, std::chrono::milliseconds(0), alloc);
+    }
+    // we still have space left in the buffer
+    if (count > 0) {
+        // wait on the event queue until timeout
+        consume_batch(get_consumer_queue().queue, messages, count, timeout, alloc);
+    }
+    return messages;
+}
+
+template <typename Allocator>
+void RoundRobinPollStrategy::consume_batch(Queue& queue,
+                                           std::vector<Message, Allocator>& messages,
+                                           ssize_t& count,
+                                           std::chrono::milliseconds timeout,
+                                           const Allocator& alloc) {
+    std::vector<Message, Allocator> queue_messages = queue.consume_batch(count, timeout, alloc);
+    if (queue_messages.empty()) {
+        return;
+    }
+    // concatenate both lists
+    messages.insert(messages.end(),
+                    make_move_iterator(queue_messages.begin()),
+                    make_move_iterator(queue_messages.end()));
+    // reduce total batch count
+    count -= queue_messages.size();
+}
+
 } //cppkafka
 
 #endif //CPPKAFKA_ROUNDROBIN_POLL_STRATEGY_H

diff --git a/src/consumer.cpp b/src/consumer.cpp
index 1c2c6b80..89e1a474 100644
--- a/src/consumer.cpp
+++ b/src/consumer.cpp
@@ -44,12 +44,13 @@ using std::ostringstream;
 using std::chrono::milliseconds;
 using std::toupper;
 using std::equal;
+using std::allocator;
 
 namespace cppkafka {
 
 // See: https://github.com/edenhill/librdkafka/issues/1792
 const int rd_kafka_queue_refcount_bug_version = 0x000b0500;
-Queue get_queue(rd_kafka_queue_t* handle) {
+Queue Consumer::get_queue(rd_kafka_queue_t* handle) {
     if (rd_kafka_version() <= rd_kafka_queue_refcount_bug_version) {
         return Queue::make_non_owning(handle);
     }
@@ -255,22 +256,12 @@ Message Consumer::poll(milliseconds timeout) {
     return rd_kafka_consumer_poll(get_handle(), static_cast<int>(timeout.count()));
 }
 
-MessageList Consumer::poll_batch(size_t max_batch_size) {
-    return poll_batch(max_batch_size, get_timeout());
+std::vector<Message> Consumer::poll_batch(size_t max_batch_size) {
+    return poll_batch(max_batch_size, get_timeout(), allocator<Message>());
 }
 
-MessageList Consumer::poll_batch(size_t max_batch_size, milliseconds timeout) {
-    vector<rd_kafka_message_t*> raw_messages(max_batch_size);
-    // Note that this will leak the queue when using rdkafka < 0.11.5 (see get_queue comment)
-    Queue queue(get_queue(rd_kafka_queue_get_consumer(get_handle())));
-    ssize_t result = rd_kafka_consume_batch_queue(queue.get_handle() , timeout.count(), raw_messages.data(),
-                                                  raw_messages.size());
-    if (result == -1) {
-        check_error(rd_kafka_last_error());
-        // on the off-chance that check_error() does not throw an error
-        return MessageList();
-    }
-    return MessageList(raw_messages.begin(), raw_messages.begin() + result);
+std::vector<Message> Consumer::poll_batch(size_t max_batch_size, milliseconds timeout) {
+    return poll_batch(max_batch_size, timeout, allocator<Message>());
 }
 
 Queue Consumer::get_main_queue() const {

diff --git a/src/queue.cpp b/src/queue.cpp
index 7e220e5a..909fd768 100644
--- a/src/queue.cpp
+++ b/src/queue.cpp
@@ -32,6 +32,7 @@ using std::vector;
 using std::exception;
 using std::chrono::milliseconds;
+using std::allocator;
 
 namespace cppkafka {
 
@@ -94,25 +95,13 @@ Message Queue::consume(milliseconds timeout) const {
     return Message(rd_kafka_consume_queue(handle_.get(), static_cast<int>(timeout.count())));
 }
 
-MessageList Queue::consume_batch(size_t max_batch_size) const {
-    return consume_batch(max_batch_size, timeout_ms_);
+std::vector<Message> Queue::consume_batch(size_t max_batch_size) const {
+    return consume_batch(max_batch_size, timeout_ms_, allocator<Message>());
 }
 
-MessageList Queue::consume_batch(size_t max_batch_size, milliseconds timeout) const {
-    vector<rd_kafka_message_t*> raw_messages(max_batch_size);
-    ssize_t result = rd_kafka_consume_batch_queue(handle_.get(),
-                                                  static_cast<int>(timeout.count()),
-                                                  raw_messages.data(),
-                                                  raw_messages.size());
-    if (result == -1) {
-        rd_kafka_resp_err_t error = rd_kafka_last_error();
-        if (error != RD_KAFKA_RESP_ERR_NO_ERROR) {
-            throw QueueException(error);
-        }
-        return MessageList();
-    }
-    // Build message list
-    return MessageList(raw_messages.begin(), raw_messages.begin() + result);
+std::vector<Message> Queue::consume_batch(size_t max_batch_size,
+                                          milliseconds timeout) const {
+    return consume_batch(max_batch_size, timeout, allocator<Message>());
 }
 
 } //cppkafka

diff --git a/src/utils/roundrobin_poll_strategy.cpp b/src/utils/roundrobin_poll_strategy.cpp
index 5d5fc7a6..9ea13cb9 100644
--- a/src/utils/roundrobin_poll_strategy.cpp
+++ b/src/utils/roundrobin_poll_strategy.cpp
@@ -32,6 +32,7 @@ using std::string;
 using std::chrono::milliseconds;
 using std::make_move_iterator;
+using std::allocator;
 
 namespace cppkafka {
 
@@ -67,46 +68,15 @@ Message RoundRobinPollStrategy::poll(milliseconds timeout) {
     return get_consumer_queue().queue.consume(timeout);
 }
 
-MessageList RoundRobinPollStrategy::poll_batch(size_t max_batch_size) {
-    return poll_batch(max_batch_size, get_consumer().get_timeout());
+std::vector<Message> RoundRobinPollStrategy::poll_batch(size_t max_batch_size) {
+    return poll_batch(max_batch_size, get_consumer().get_timeout(), allocator<Message>());
 }
 
-MessageList RoundRobinPollStrategy::poll_batch(size_t max_batch_size, milliseconds timeout) {
-    MessageList messages;
-    ssize_t count = max_batch_size;
-
-    // batch from the group event queue first (non-blocking)
-    consume_batch(get_consumer_queue().queue, messages, count, milliseconds(0));
-    size_t num_queues = get_partition_queues().size();
-    while ((count > 0) && (num_queues--)) {
-        // batch from the next partition (non-blocking)
-        consume_batch(get_next_queue().queue, messages, count, milliseconds(0));
-    }
-    // we still have space left in the buffer
-    if (count > 0) {
-        // wait on the event queue until timeout
-        consume_batch(get_consumer_queue().queue, messages, count, timeout);
-    }
-    return messages;
+std::vector<Message> RoundRobinPollStrategy::poll_batch(size_t max_batch_size,
+                                                        milliseconds timeout) {
+    return poll_batch(max_batch_size, timeout, allocator<Message>());
 }
 
-void RoundRobinPollStrategy::consume_batch(Queue& queue,
-                                           MessageList& messages,
-                                           ssize_t& count,
-                                           milliseconds timeout) {
-    MessageList queue_messages = queue.consume_batch(count, timeout);
-    if (queue_messages.empty()) {
-        return;
-    }
-    // concatenate both lists
-    messages.insert(messages.end(),
-                    make_move_iterator(queue_messages.begin()),
-                    make_move_iterator(queue_messages.end()));
-    // reduce total batch count
-    count -= queue_messages.size();
-}
-
-
 void RoundRobinPollStrategy::restore_forwarding() {
     // forward all partition queues
     for (const auto& toppar : get_partition_queues()) {

diff --git a/tests/test_utils.h b/tests/test_utils.h
index b6943e6e..8b882a2c 100644
--- a/tests/test_utils.h
+++ b/tests/test_utils.h
@@ -48,9 +48,9 @@ class PollStrategyAdapter : public Consumer {
     void delete_polling_strategy();
     Message poll();
     Message poll(std::chrono::milliseconds timeout);
-    MessageList poll_batch(size_t max_batch_size);
-    MessageList poll_batch(size_t max_batch_size,
-                           std::chrono::milliseconds timeout);
+    std::vector<Message> poll_batch(size_t max_batch_size);
+    std::vector<Message> poll_batch(size_t max_batch_size,
+                                    std::chrono::milliseconds timeout);
     void set_timeout(std::chrono::milliseconds timeout);
     std::chrono::milliseconds get_timeout();
 private:

diff --git a/tests/test_utils_impl.h b/tests/test_utils_impl.h
index e978de24..46b423a8 100644
--- a/tests/test_utils_impl.h
+++ b/tests/test_utils_impl.h
@@ -19,7 +19,6 @@
 using cppkafka::Consumer;
 using cppkafka::BasicConsumerDispatcher;
 using cppkafka::Message;
-using cppkafka::MessageList;
 using cppkafka::TopicPartition;
 
 //==================================================================================
@@ -89,7 +88,7 @@ BasicConsumerRunner<ConsumerType>::~BasicConsumerRunner() {
 }
 
 template <typename ConsumerType>
-const MessageList& BasicConsumerRunner<ConsumerType>::get_messages() const {
+const std::vector<Message>& BasicConsumerRunner<ConsumerType>::get_messages() const {
     return messages_;
 }
 
@@ -135,7 +134,7 @@ Message PollStrategyAdapter::poll(milliseconds timeout) {
 }
 
 inline
-MessageList PollStrategyAdapter::poll_batch(size_t max_batch_size) {
+std::vector<Message> PollStrategyAdapter::poll_batch(size_t max_batch_size) {
     if (strategy_) {
         return strategy_->poll_batch(max_batch_size);
     }
@@ -143,8 +142,8 @@ MessageList PollStrategyAdapter::poll_batch(size_t max_batch_size) {
 }
 
 inline
-MessageList PollStrategyAdapter::poll_batch(size_t max_batch_size,
-                                            milliseconds timeout) {
+std::vector<Message> PollStrategyAdapter::poll_batch(size_t max_batch_size,
+                                                     milliseconds timeout) {
     if (strategy_) {
         return strategy_->poll_batch(max_batch_size, timeout);
    }
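
---

Usage sketch for the consumer side of this change (not part of the patch): the new `poll_batch` overloads accept any standard-conforming allocator and return a `std::vector<Message, Allocator>` built with it, while the allocator-free overloads keep the old behaviour. The broker address, group id, topic name and the `CountingAllocator` type below are illustrative placeholders, not cppkafka APIs.

```cpp
#include <cppkafka/configuration.h>
#include <cppkafka/consumer.h>
#include <chrono>
#include <cstddef>
#include <new>
#include <vector>

using cppkafka::Configuration;
using cppkafka::Consumer;
using cppkafka::Message;

// Minimal C++11 allocator that counts the bytes it hands out; it stands in for
// the pool/arena allocator a caller might actually want to plug in.
template <typename T>
struct CountingAllocator {
    using value_type = T;
    CountingAllocator() = default;
    template <typename U>
    CountingAllocator(const CountingAllocator<U>&) noexcept {}
    T* allocate(std::size_t n) {
        bytes_allocated += n * sizeof(T);
        return static_cast<T*>(::operator new(n * sizeof(T)));
    }
    void deallocate(T* ptr, std::size_t) noexcept { ::operator delete(ptr); }
    static std::size_t bytes_allocated;
};
template <typename T>
std::size_t CountingAllocator<T>::bytes_allocated = 0;

template <typename T, typename U>
bool operator==(const CountingAllocator<T>&, const CountingAllocator<U>&) { return true; }
template <typename T, typename U>
bool operator!=(const CountingAllocator<T>&, const CountingAllocator<U>&) { return false; }

int main() {
    Configuration config = {
        { "metadata.broker.list", "127.0.0.1:9092" },  // placeholder broker
        { "group.id", "allocator-example" }            // placeholder group id
    };
    Consumer consumer(config);
    consumer.subscribe({ "example-topic" });           // placeholder topic

    // Existing call shape is untouched and still returns std::vector<Message>.
    std::vector<Message> batch = consumer.poll_batch(10);

    // New overload: the returned vector is constructed with the supplied
    // allocator, so its element storage comes from CountingAllocator.
    CountingAllocator<Message> alloc;
    std::vector<Message, CountingAllocator<Message>> custom_batch =
            consumer.poll_batch(10, std::chrono::milliseconds(100), alloc);

    (void)batch;
    (void)custom_batch;
    return 0;
}
```

The same pattern applies to `Queue::consume_batch` and `RoundRobinPollStrategy::poll_batch`, which forward the allocator in the same way.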
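A second sketch, for the producer side (also not part of the patch): the `BufferedProducer` now takes the allocator both as a template parameter and as a constructor argument so that stateful allocators can be passed in; the internal buffer becomes `std::deque<Builder, Allocator>`. `PoolAllocator` below is just an alias for `std::allocator`, standing in for whatever allocator a caller would really supply; broker and topic names are placeholders. The fluent builder usage mirrors the existing README example.

```cpp
#include <cppkafka/configuration.h>
#include <cppkafka/utils/buffered_producer.h>
#include <memory>
#include <string>

using cppkafka::BufferedProducer;
using cppkafka::ConcreteMessageBuilder;
using cppkafka::Configuration;

// Stand-in alias: a real use case would substitute a pool/arena allocator with
// the same interface; std::allocator keeps this sketch self-contained.
template <typename T>
using PoolAllocator = std::allocator<T>;

int main() {
    Configuration config = {
        { "metadata.broker.list", "127.0.0.1:9092" }   // placeholder broker
    };

    // The Allocator parameter defaults to
    // std::allocator<ConcreteMessageBuilder<BufferType>>, so existing
    // BufferedProducer<std::string> code keeps compiling unchanged.
    using Builder = ConcreteMessageBuilder<std::string>;
    using Alloc = PoolAllocator<Builder>;
    BufferedProducer<std::string, Alloc> producer(config, Alloc());

    // Buffered messages now live in a std::deque<Builder, Alloc>.
    std::string payload = "hello world";
    producer.add_message(producer.make_builder("example-topic").payload(payload)); // placeholder topic
    producer.flush();
    return 0;
}
```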