
Commit

feat: add batch byte size limit configuration
mhmtszr committed May 15, 2024
1 parent 6e542f1 commit ec0ebc6
Showing 6 changed files with 238 additions and 40 deletions.
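At a glance: batch consumers can now be capped by total payload bytes as well as by message count. The new BatchConfiguration.MessageGroupByteSizeLimit accepts an int, a uint, a numeric string, or a human-readable size string such as "500kb" or "10mb", resolved by the new ResolveUnionIntOrStringValue helper. A minimal usage sketch; the BatchConfiguration field names come from this diff, while the handler and the surrounding consumer setup are assumptions:

	// Hypothetical wiring: only the BatchConfiguration fields are from this commit.
	cfg.BatchConfiguration = BatchConfiguration{
		BatchConsumeFn:            myBatchConsumeFn, // hypothetical handler
		MessageGroupLimit:         100,              // flush after 100 buffered messages...
		MessageGroupByteSizeLimit: "1mb",            // ...or once a batch would exceed ~1 MiB
	}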
73 changes: 53 additions & 20 deletions batch_consumer.go
@@ -17,7 +17,8 @@ type batchConsumer struct {
	consumeFn  BatchConsumeFn
	preBatchFn PreBatchFn

-	messageGroupLimit int
+	messageGroupLimit         int
+	messageGroupByteSizeLimit int
}

func (b *batchConsumer) Pause() {
@@ -34,11 +35,17 @@ func newBatchConsumer(cfg *ConsumerConfig) (Consumer, error) {
		return nil, err
	}

+	messageGroupByteSizeLimit, err := ResolveUnionIntOrStringValue(cfg.BatchConfiguration.MessageGroupByteSizeLimit)
+	if err != nil {
+		return nil, err
+	}
+
	c := batchConsumer{
-		base:              consumerBase,
-		consumeFn:         cfg.BatchConfiguration.BatchConsumeFn,
-		preBatchFn:        cfg.BatchConfiguration.PreBatchFn,
-		messageGroupLimit: cfg.BatchConfiguration.MessageGroupLimit,
+		base:                      consumerBase,
+		consumeFn:                 cfg.BatchConfiguration.BatchConsumeFn,
+		preBatchFn:                cfg.BatchConfiguration.PreBatchFn,
+		messageGroupLimit:         cfg.BatchConfiguration.MessageGroupLimit,
+		messageGroupByteSizeLimit: messageGroupByteSizeLimit,
	}

	if cfg.RetryEnabled {
@@ -86,29 +93,35 @@ func (b *batchConsumer) startBatch() {
	defer ticker.Stop()

	maximumMessageLimit := b.messageGroupLimit * b.concurrency
+	maximumMessageByteSizeLimit := b.messageGroupByteSizeLimit * b.concurrency
	messages := make([]*Message, 0, maximumMessageLimit)
	commitMessages := make([]kafka.Message, 0, maximumMessageLimit)

+	messageByteSize := 0
	for {
		select {
		case <-ticker.C:
			if len(messages) == 0 {
				continue
			}

-			b.consume(&messages, &commitMessages)
+			b.consume(&messages, &commitMessages, &messageByteSize)
		case msg, ok := <-b.incomingMessageStream:
			if !ok {
				close(b.batchConsumingStream)
				close(b.messageProcessedStream)
				return
			}

+			if maximumMessageByteSizeLimit != 0 && messageByteSize+len(msg.message.Value) > maximumMessageByteSizeLimit {
+				b.consume(&messages, &commitMessages, &messageByteSize)
+			}
+
			messages = append(messages, msg.message)
			commitMessages = append(commitMessages, *msg.kafkaMessage)
+			messageByteSize += len(msg.message.Value)

			if len(messages) == maximumMessageLimit {
-				b.consume(&messages, &commitMessages)
+				b.consume(&messages, &commitMessages, &messageByteSize)
			}
		}
	}
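Worth making the flush semantics explicit: the effective cap is the per-group byte limit multiplied by the consumer's concurrency, and an incoming message that would push the accumulated bytes over that cap forces a flush of the buffered batch before the new message is appended. For example, with messageGroupByteSizeLimit = 100 and concurrency = 2 the cap is 200 bytes; if 180 bytes are already buffered and a 30-byte message arrives, the buffer is consumed first and the 30-byte message opens the next batch. A limit of 0 (the default when MessageGroupByteSizeLimit is unset) disables byte-based flushing entirely.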
@@ -126,31 +139,50 @@ func (b *batchConsumer) setupConcurrentWorkers() {
	}
}

-func chunkMessages(allMessages *[]*Message, chunkSize int) [][]*Message {
+func chunkMessages(allMessages *[]*Message, chunkSize int, chunkByteSize int) [][]*Message {
	var chunks [][]*Message

	allMessageList := *allMessages
-	for i := 0; i < len(allMessageList); i += chunkSize {
-		end := i + chunkSize
-
-		// necessary check to avoid slicing beyond
-		// slice capacity
-		if end > len(allMessageList) {
-			end = len(allMessageList)
-		}
-
-		chunks = append(chunks, allMessageList[i:end])
+	var currentChunk []*Message
+	currentChunkSize := 0
+	currentChunkBytes := 0
+
+	for _, message := range allMessageList {
+		messageByteSize := len(message.Value)
+
+		// Check if adding this message would exceed either the chunk size or the byte size
+		if len(currentChunk) >= chunkSize || (chunkByteSize != 0 && currentChunkBytes+messageByteSize > chunkByteSize) {
+			// Avoid too low chunkByteSize
+			if len(currentChunk) == 0 {
+				panic("invalid chunk byte size, please increase it")
+			}
+			// If it does, finalize the current chunk and start a new one
+			chunks = append(chunks, currentChunk)
+			currentChunk = []*Message{}
+			currentChunkSize = 0
+			currentChunkBytes = 0
+		}
+
+		// Add the message to the current chunk
+		currentChunk = append(currentChunk, message)
+		currentChunkSize++
+		currentChunkBytes += messageByteSize
	}

+	// Add the last chunk if it has any messages
+	if len(currentChunk) > 0 {
+		chunks = append(chunks, currentChunk)
+	}
+
	return chunks
}

-func (b *batchConsumer) consume(allMessages *[]*Message, commitMessages *[]kafka.Message) {
-	chunks := chunkMessages(allMessages, b.messageGroupLimit)
+func (b *batchConsumer) consume(allMessages *[]*Message, commitMessages *[]kafka.Message, messageByteSizeLimit *int) {
+	chunks := chunkMessages(allMessages, b.messageGroupLimit, b.messageGroupByteSizeLimit)

	if b.preBatchFn != nil {
		preBatchResult := b.preBatchFn(*allMessages)
-		chunks = chunkMessages(&preBatchResult, b.messageGroupLimit)
+		chunks = chunkMessages(&preBatchResult, b.messageGroupLimit, b.messageGroupByteSizeLimit)
	}

	// Send the messages to process
@@ -170,6 +202,7 @@ func (b *batchConsumer) consume(allMessages *[]*Message, commitMessages *[]kafka.Message, messageByteSizeLimit *int) {
	// Clearing resources
	*commitMessages = (*commitMessages)[:0]
	*allMessages = (*allMessages)[:0]
+	*messageByteSizeLimit = 0
}

func (b *batchConsumer) process(chunkMessages []*Message) {
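The rewritten chunkMessages closes a chunk on whichever limit is hit first, message count or accumulated payload bytes; a chunkByteSize of 0 disables the byte check. A small sketch of the byte-limit path, mirroring the new test case in batch_consumer_test.go (where every createMessages message now carries a 4-byte "test" payload):

	// Three 4-byte messages with chunkSize 100 and chunkByteSize 4:
	// the byte limit closes each chunk after a single message.
	msgs := createMessages(0, 3)
	chunks := chunkMessages(&msgs, 100, 4)
	fmt.Println(len(chunks)) // 3, one message per chunk

Note the guard for a limit that is too small: if a single message already exceeds a non-zero chunkByteSize, the function panics with "invalid chunk byte size, please increase it" rather than emitting an oversized chunk.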
48 changes: 33 additions & 15 deletions batch_consumer_test.go
@@ -301,52 +301,69 @@ func Test_batchConsumer_process(t *testing.T) {

func Test_batchConsumer_chunk(t *testing.T) {
	tests := []struct {
-		allMessages []*Message
-		expected    [][]*Message
-		chunkSize   int
+		allMessages   []*Message
+		expected      [][]*Message
+		chunkSize     int
+		chunkByteSize int
	}{
		{
-			allMessages: createMessages(0, 9),
-			chunkSize:   3,
+			allMessages:   createMessages(0, 9),
+			chunkSize:     3,
+			chunkByteSize: 10000,
			expected: [][]*Message{
				createMessages(0, 3),
				createMessages(3, 6),
				createMessages(6, 9),
			},
		},
		{
-			allMessages: []*Message{},
-			chunkSize:   3,
-			expected:    [][]*Message{},
+			allMessages:   []*Message{},
+			chunkSize:     3,
+			chunkByteSize: 10000,
+			expected:      [][]*Message{},
		},
		{
-			allMessages: createMessages(0, 1),
-			chunkSize:   3,
+			allMessages:   createMessages(0, 1),
+			chunkSize:     3,
+			chunkByteSize: 10000,
			expected: [][]*Message{
				createMessages(0, 1),
			},
		},
		{
-			allMessages: createMessages(0, 8),
-			chunkSize:   3,
+			allMessages:   createMessages(0, 8),
+			chunkSize:     3,
+			chunkByteSize: 10000,
			expected: [][]*Message{
				createMessages(0, 3),
				createMessages(3, 6),
				createMessages(6, 8),
			},
		},
		{
-			allMessages: createMessages(0, 3),
-			chunkSize:   3,
+			allMessages:   createMessages(0, 3),
+			chunkSize:     3,
+			chunkByteSize: 10000,
			expected: [][]*Message{
				createMessages(0, 3),
			},
		},
+
+		{
+			allMessages:   createMessages(0, 3),
+			chunkSize:     100,
+			chunkByteSize: 4,
+			expected: [][]*Message{
+				createMessages(0, 1),
+				createMessages(1, 2),
+				createMessages(2, 3),
+			},
+		},
	}

	for i, tc := range tests {
		t.Run(strconv.Itoa(i), func(t *testing.T) {
-			chunkedMessages := chunkMessages(&tc.allMessages, tc.chunkSize)
+			chunkedMessages := chunkMessages(&tc.allMessages, tc.chunkSize, tc.chunkByteSize)

			if !reflect.DeepEqual(chunkedMessages, tc.expected) && !(len(chunkedMessages) == 0 && len(tc.expected) == 0) {
				t.Errorf("For chunkSize %d, expected %v, but got %v", tc.chunkSize, tc.expected, chunkedMessages)
@@ -444,6 +461,7 @@ func createMessages(partitionStart int, partitionEnd int) []*Message {
	for i := partitionStart; i < partitionEnd; i++ {
		messages = append(messages, &Message{
			Partition: i,
+			Value:     []byte("test"),
		})
	}
	return messages
7 changes: 4 additions & 3 deletions consumer_config.go
@@ -165,9 +165,10 @@ type RetryConfiguration struct {
}

type BatchConfiguration struct {
-	BatchConsumeFn    BatchConsumeFn
-	PreBatchFn        PreBatchFn
-	MessageGroupLimit int
+	BatchConsumeFn            BatchConsumeFn
+	PreBatchFn                PreBatchFn
+	MessageGroupLimit         int
+	MessageGroupByteSizeLimit any
}

func (cfg *ConsumerConfig) newKafkaDialer() (*kafka.Dialer, error) {
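MessageGroupByteSizeLimit is typed as any on purpose: the value is passed through ResolveUnionIntOrStringValue (added in data_units.go below), so all of the following set the same 20 MiB limit, and leaving the field nil resolves to 0, which disables byte-based batching:

	MessageGroupByteSizeLimit: 20971520   // raw byte count
	MessageGroupByteSizeLimit: "20971520" // numeric string
	MessageGroupByteSizeLimit: "20mb"     // size string, units are case-insensitive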
63 changes: 63 additions & 0 deletions data_units.go
@@ -0,0 +1,63 @@
package kafka

import (
	"fmt"
	"strconv"
	"strings"
)

func ResolveUnionIntOrStringValue(input any) (int, error) {
	switch value := input.(type) {
	case int:
		return value, nil
	case uint:
		return int(value), nil
	case nil:
		return 0, nil
	case string:
		intValue, err := strconv.ParseInt(value, 10, 64)
		if err == nil {
			return int(intValue), nil
		}

		result, err := convertSizeUnitToByte(value)
		if err != nil {
			return 0, err
		}

		return result, nil
	}

	return 0, fmt.Errorf("invalid input: %v", input)
}

[Check failure, Code scanning / CodeQL: Incorrect conversion between integer types (High). Incorrect conversion of a signed 64-bit integer from strconv.ParseInt to a lower bit size type int without an upper bound check, at `return int(intValue), nil` above.]

func convertSizeUnitToByte(str string) (int, error) {
	if len(str) < 2 {
		return 0, fmt.Errorf("invalid input: %s", str)
	}

	// Extract the numeric part of the input
	sizeStr := str[:len(str)-2]
	sizeStr = strings.TrimSpace(sizeStr)
	sizeStr = strings.ReplaceAll(sizeStr, ",", ".")

	size, err := strconv.ParseFloat(sizeStr, 64)
	if err != nil {
		return 0, fmt.Errorf("cannot extract numeric part for the input %s, err = %w", str, err)
	}

	// Determine the unit (B, KB, MB, GB)
	unit := str[len(str)-2:]
	switch strings.ToUpper(unit) {
	case "B":
		return int(size), nil
	case "KB":
		return int(size * 1024), nil
	case "MB":
		return int(size * 1024 * 1024), nil
	case "GB":
		return int(size * 1024 * 1024 * 1024), nil
	default:
		return 0, fmt.Errorf("unsupported unit: %s, you can specify one of B, KB, MB and GB", unit)
	}
}
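A quick illustration of the resolution rules, matching the tests below:

	v1, _ := ResolveUnionIntOrStringValue("500kb")    // 512000
	v2, _ := ResolveUnionIntOrStringValue("15971520") // 15971520, plain numeric string
	v3, _ := ResolveUnionIntOrStringValue(nil)        // 0, byte limit disabled
	_, err := ResolveUnionIntOrStringValue("15TB")    // error: unsupported unit
	fmt.Println(v1, v2, v3, err)

One caveat: the unit is always taken as the final two characters of the string, so a bare "B" suffix (for example "512B") appears unable to ever match the "B" case; plain integers are the way to pass raw byte counts.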
85 changes: 85 additions & 0 deletions data_units_test.go
@@ -0,0 +1,85 @@
package kafka

import "testing"

func TestDcp_ResolveConnectionBufferSize(t *testing.T) {
	tests := []struct {
		input any
		name  string
		want  int
	}{
		{
			name:  "When_Client_Gives_Int_Value",
			input: 20971520,
			want:  20971520,
		},
		{
			name:  "When_Client_Gives_UInt_Value",
			input: uint(10971520),
			want:  10971520,
		},
		{
			name:  "When_Client_Gives_StringInt_Value",
			input: "15971520",
			want:  15971520,
		},
		{
			name:  "When_Client_Gives_KB_Value",
			input: "500kb",
			want:  500 * 1024,
		},
		{
			name:  "When_Client_Gives_MB_Value",
			input: "10mb",
			want:  10 * 1024 * 1024,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got, _ := ResolveUnionIntOrStringValue(tt.input); got != tt.want {
				t.Errorf("ResolveConnectionBufferSize() = %v, want %v", got, tt.want)
			}
		})
	}
}

func TestConvertToBytes(t *testing.T) {
	testCases := []struct {
		input    string
		expected int
		err      bool
	}{
		{"1kb", 1024, false},
		{"5mb", 5 * 1024 * 1024, false},
		{"5,5mb", 5.5 * 1024 * 1024, false},
		{"8.5mb", 8.5 * 1024 * 1024, false},
		{"10,25 mb", 10.25 * 1024 * 1024, false},
		{"10gb", 10 * 1024 * 1024 * 1024, false},
		{"1KB", 1024, false},
		{"5MB", 5 * 1024 * 1024, false},
		{"12 MB", 12 * 1024 * 1024, false},
		{"10GB", 10 * 1024 * 1024 * 1024, false},
		{"123", 0, true},
		{"15TB", 0, true},
		{"invalid", 0, true},
		{"", 0, true},
		{"123 KB", 123 * 1024, false},
		{"1 MB", 1 * 1024 * 1024, false},
	}

	for _, tc := range testCases {
		result, err := convertSizeUnitToByte(tc.input)

		if tc.err && err == nil {
			t.Errorf("Expected an error for input %s, but got none", tc.input)
		}

		if !tc.err && err != nil {
			t.Errorf("Unexpected error for input %s: %v", tc.input, err)
		}

		if result != tc.expected {
			t.Errorf("For input %s, expected %d bytes, but got %d", tc.input, tc.expected, result)
		}
	}
}
