feat: add batch byte size limit configuration #129

Merged · 5 commits · Jun 3, 2024
Changes from 1 commit

73 changes: 53 additions & 20 deletions batch_consumer.go
@@ -17,7 +17,8 @@ type batchConsumer struct {
 	consumeFn  BatchConsumeFn
 	preBatchFn PreBatchFn

-	messageGroupLimit int
+	messageGroupLimit         int
+	messageGroupByteSizeLimit int
 }

 func (b *batchConsumer) Pause() {
@@ -34,11 +35,17 @@ func newBatchConsumer(cfg *ConsumerConfig) (Consumer, error) {
 		return nil, err
 	}

+	messageGroupByteSizeLimit, err := ResolveUnionIntOrStringValue(cfg.BatchConfiguration.MessageGroupByteSizeLimit)
+	if err != nil {
+		return nil, err
+	}
+
 	c := batchConsumer{
-		base:              consumerBase,
-		consumeFn:         cfg.BatchConfiguration.BatchConsumeFn,
-		preBatchFn:        cfg.BatchConfiguration.PreBatchFn,
-		messageGroupLimit: cfg.BatchConfiguration.MessageGroupLimit,
+		base:                      consumerBase,
+		consumeFn:                 cfg.BatchConfiguration.BatchConsumeFn,
+		preBatchFn:                cfg.BatchConfiguration.PreBatchFn,
+		messageGroupLimit:         cfg.BatchConfiguration.MessageGroupLimit,
+		messageGroupByteSizeLimit: messageGroupByteSizeLimit,
 	}

 	if cfg.RetryEnabled {
@@ -86,29 +93,35 @@ func (b *batchConsumer) startBatch() {
 	defer ticker.Stop()

 	maximumMessageLimit := b.messageGroupLimit * b.concurrency
+	maximumMessageByteSizeLimit := b.messageGroupByteSizeLimit * b.concurrency
 	messages := make([]*Message, 0, maximumMessageLimit)
 	commitMessages := make([]kafka.Message, 0, maximumMessageLimit)

+	messageByteSize := 0
 	for {
 		select {
 		case <-ticker.C:
 			if len(messages) == 0 {
 				continue
 			}

-			b.consume(&messages, &commitMessages)
+			b.consume(&messages, &commitMessages, &messageByteSize)
 		case msg, ok := <-b.incomingMessageStream:
 			if !ok {
 				close(b.batchConsumingStream)
 				close(b.messageProcessedStream)
 				return
 			}

+			if maximumMessageByteSizeLimit != 0 && messageByteSize+len(msg.message.Value) > maximumMessageByteSizeLimit {
+				b.consume(&messages, &commitMessages, &messageByteSize)
+			}
+
 			messages = append(messages, msg.message)
 			commitMessages = append(commitMessages, *msg.kafkaMessage)
+			messageByteSize += len(msg.message.Value)

 			if len(messages) == maximumMessageLimit {
-				b.consume(&messages, &commitMessages)
+				b.consume(&messages, &commitMessages, &messageByteSize)
 			}
 		}
 	}
@@ -126,31 +139,50 @@ func (b *batchConsumer) setupConcurrentWorkers() {
 	}
 }

-func chunkMessages(allMessages *[]*Message, chunkSize int) [][]*Message {
+func chunkMessages(allMessages *[]*Message, chunkSize int, chunkByteSize int) [][]*Message {
 	var chunks [][]*Message

 	allMessageList := *allMessages
-	for i := 0; i < len(allMessageList); i += chunkSize {
-		end := i + chunkSize
-
-		// necessary check to avoid slicing beyond
-		// slice capacity
-		if end > len(allMessageList) {
-			end = len(allMessageList)
-		}
-
-		chunks = append(chunks, allMessageList[i:end])
-	}
+	var currentChunk []*Message
+	currentChunkSize := 0
+	currentChunkBytes := 0
+
+	for _, message := range allMessageList {
+		messageByteSize := len(message.Value)
+
+		// Check if adding this message would exceed either the chunk size or the byte size
+		if len(currentChunk) >= chunkSize || (chunkByteSize != 0 && currentChunkBytes+messageByteSize > chunkByteSize) {
+			// Avoid too low chunkByteSize
+			if len(currentChunk) == 0 {
+				panic("invalid chunk byte size, please increase it")
+			}
+			// If it does, finalize the current chunk and start a new one
+			chunks = append(chunks, currentChunk)
+			currentChunk = []*Message{}
+			currentChunkSize = 0
+			currentChunkBytes = 0
+		}
+
+		// Add the message to the current chunk
+		currentChunk = append(currentChunk, message)
+		currentChunkSize++
+		currentChunkBytes += messageByteSize
+	}
+
+	// Add the last chunk if it has any messages
+	if len(currentChunk) > 0 {
+		chunks = append(chunks, currentChunk)
+	}

 	return chunks
 }

-func (b *batchConsumer) consume(allMessages *[]*Message, commitMessages *[]kafka.Message) {
-	chunks := chunkMessages(allMessages, b.messageGroupLimit)
+func (b *batchConsumer) consume(allMessages *[]*Message, commitMessages *[]kafka.Message, messageByteSizeLimit *int) {
+	chunks := chunkMessages(allMessages, b.messageGroupLimit, b.messageGroupByteSizeLimit)

 	if b.preBatchFn != nil {
 		preBatchResult := b.preBatchFn(*allMessages)
-		chunks = chunkMessages(&preBatchResult, b.messageGroupLimit)
+		chunks = chunkMessages(&preBatchResult, b.messageGroupLimit, b.messageGroupByteSizeLimit)
 	}

 	// Send the messages to process
@@ -170,6 +202,7 @@ func (b *batchConsumer) consume(allMessages *[]*Message, commitMessages *[]kafka.Message) {
 	// Clearing resources
 	*commitMessages = (*commitMessages)[:0]
 	*allMessages = (*allMessages)[:0]
+	*messageByteSizeLimit = 0
 }

func (b *batchConsumer) process(chunkMessages []*Message) {
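Taken together: startBatch now tracks the accumulated byte size of the in-flight batch and flushes early when the next message would push it past messageGroupByteSizeLimit × concurrency, and chunkMessages closes a chunk on whichever limit is hit first, message count (chunkSize) or bytes (chunkByteSize, where 0 disables the byte check). Here is a self-contained sketch of that chunking rule, using simplified stand-in types rather than the library's internals:

```go
package main

import "fmt"

// Message is a simplified stand-in for the library's message type, for illustration only.
type Message struct {
	Partition int
	Value     []byte
}

// chunkByCountAndBytes mirrors the greedy rule in the new chunkMessages:
// close the current chunk as soon as adding the next message would exceed
// chunkSize messages or chunkByteSize bytes (a byte limit of 0 disables
// the byte check).
func chunkByCountAndBytes(msgs []*Message, chunkSize, chunkByteSize int) [][]*Message {
	var chunks [][]*Message
	var current []*Message
	currentBytes := 0

	for _, m := range msgs {
		size := len(m.Value)
		if len(current) >= chunkSize || (chunkByteSize != 0 && currentBytes+size > chunkByteSize) {
			// A single message bigger than chunkByteSize can never fit,
			// so fail loudly, exactly as the library does.
			if len(current) == 0 {
				panic("invalid chunk byte size, please increase it")
			}
			chunks = append(chunks, current)
			current = nil
			currentBytes = 0
		}
		current = append(current, m)
		currentBytes += size
	}
	if len(current) > 0 {
		chunks = append(chunks, current)
	}
	return chunks
}

func main() {
	msgs := []*Message{
		{Partition: 0, Value: []byte("test")}, // 4 bytes each, like createMessages in the tests
		{Partition: 1, Value: []byte("test")},
		{Partition: 2, Value: []byte("test")},
	}

	// chunkSize 100 never triggers; chunkByteSize 4 closes a chunk after
	// every message, so we get three single-message chunks.
	for i, chunk := range chunkByCountAndBytes(msgs, 100, 4) {
		fmt.Printf("chunk %d: %d message(s)\n", i, len(chunk))
	}
}
```

With chunkSize 100, chunkByteSize 4, and three 4-byte messages, every message lands in its own chunk, which is exactly the new table-driven case added to batch_consumer_test.go below. The panic guard matters: if a single message is already larger than chunkByteSize, no chunk could ever accept it, so the function fails loudly instead of looping forever.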
48 changes: 33 additions & 15 deletions batch_consumer_test.go
@@ -301,52 +301,69 @@ func Test_batchConsumer_process(t *testing.T) {

 func Test_batchConsumer_chunk(t *testing.T) {
 	tests := []struct {
-		allMessages []*Message
-		expected    [][]*Message
-		chunkSize   int
+		allMessages   []*Message
+		expected      [][]*Message
+		chunkSize     int
+		chunkByteSize int
 	}{
 		{
-			allMessages: createMessages(0, 9),
-			chunkSize:   3,
+			allMessages:   createMessages(0, 9),
+			chunkSize:     3,
+			chunkByteSize: 10000,
 			expected: [][]*Message{
 				createMessages(0, 3),
 				createMessages(3, 6),
 				createMessages(6, 9),
 			},
 		},
 		{
-			allMessages: []*Message{},
-			chunkSize:   3,
-			expected:    [][]*Message{},
+			allMessages:   []*Message{},
+			chunkSize:     3,
+			chunkByteSize: 10000,
+			expected:      [][]*Message{},
 		},
 		{
-			allMessages: createMessages(0, 1),
-			chunkSize:   3,
+			allMessages:   createMessages(0, 1),
+			chunkSize:     3,
+			chunkByteSize: 10000,
 			expected: [][]*Message{
 				createMessages(0, 1),
 			},
 		},
 		{
-			allMessages: createMessages(0, 8),
-			chunkSize:   3,
+			allMessages:   createMessages(0, 8),
+			chunkSize:     3,
+			chunkByteSize: 10000,
 			expected: [][]*Message{
 				createMessages(0, 3),
 				createMessages(3, 6),
 				createMessages(6, 8),
 			},
 		},
 		{
-			allMessages: createMessages(0, 3),
-			chunkSize:   3,
+			allMessages:   createMessages(0, 3),
+			chunkSize:     3,
+			chunkByteSize: 10000,
 			expected: [][]*Message{
 				createMessages(0, 3),
 			},
 		},
+
+		{
+			allMessages:   createMessages(0, 3),
+			chunkSize:     100,
+			chunkByteSize: 4,
+			expected: [][]*Message{
+				createMessages(0, 1),
+				createMessages(1, 2),
+				createMessages(2, 3),
+			},
+		},
 	}

 	for i, tc := range tests {
 		t.Run(strconv.Itoa(i), func(t *testing.T) {
-			chunkedMessages := chunkMessages(&tc.allMessages, tc.chunkSize)
+			chunkedMessages := chunkMessages(&tc.allMessages, tc.chunkSize, tc.chunkByteSize)

 			if !reflect.DeepEqual(chunkedMessages, tc.expected) && !(len(chunkedMessages) == 0 && len(tc.expected) == 0) {
 				t.Errorf("For chunkSize %d, expected %v, but got %v", tc.chunkSize, tc.expected, chunkedMessages)
@@ -444,6 +461,7 @@ func createMessages(partitionStart int, partitionEnd int) []*Message {
 	for i := partitionStart; i < partitionEnd; i++ {
 		messages = append(messages, &Message{
 			Partition: i,
+			Value:     []byte("test"),
 		})
 	}
 	return messages
7 changes: 4 additions & 3 deletions consumer_config.go
@@ -165,9 +165,10 @@ type RetryConfiguration struct {
 }

 type BatchConfiguration struct {
-	BatchConsumeFn    BatchConsumeFn
-	PreBatchFn        PreBatchFn
-	MessageGroupLimit int
+	BatchConsumeFn            BatchConsumeFn
+	PreBatchFn                PreBatchFn
+	MessageGroupLimit         int
+	MessageGroupByteSizeLimit any
 }

 func (cfg *ConsumerConfig) newKafkaDialer() (*kafka.Dialer, error) {
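For callers, the new option slots into the existing BatchConfiguration. A minimal sketch of how it might be set (the import path and the omitted handler wiring are assumptions for illustration, not part of this diff):

```go
package main

import (
	kafka "github.com/Trendyol/kafka-konsumer/v2" // assumed import path
)

func main() {
	_ = kafka.BatchConfiguration{
		// BatchConsumeFn and PreBatchFn are wired up as before; only
		// MessageGroupByteSizeLimit is new in this PR.
		MessageGroupLimit: 1000,

		// Accepts an int (bytes), a uint, or a human-readable string such
		// as "5mb"; anything ResolveUnionIntOrStringValue understands.
		// Leaving it nil resolves to 0, which disables byte-size limiting.
		MessageGroupByteSizeLimit: "5mb",
	}
}
```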
63 changes: 63 additions & 0 deletions data_units.go
@@ -0,0 +1,63 @@
package kafka

import (
	"fmt"
	"strconv"
	"strings"
)

func ResolveUnionIntOrStringValue(input any) (int, error) {
	switch value := input.(type) {
	case int:
		return value, nil
	case uint:
		return int(value), nil
	case nil:
		return 0, nil
	case string:
		intValue, err := strconv.ParseInt(value, 10, 64)
		if err == nil {
			return int(intValue), nil
		}

		result, err := convertSizeUnitToByte(value)
		if err != nil {
			return 0, err
		}

		return result, nil
	}

	return 0, fmt.Errorf("invalid input: %v", input)
}

func convertSizeUnitToByte(str string) (int, error) {
	if len(str) < 2 {
		return 0, fmt.Errorf("invalid input: %s", str)
	}

	// Extract the numeric part of the input
	sizeStr := str[:len(str)-2]
	sizeStr = strings.TrimSpace(sizeStr)
	sizeStr = strings.ReplaceAll(sizeStr, ",", ".")

	size, err := strconv.ParseFloat(sizeStr, 64)
	if err != nil {
		return 0, fmt.Errorf("cannot extract numeric part for the input %s, err = %w", str, err)
	}

	// Determine the unit (B, KB, MB, GB)
	unit := str[len(str)-2:]
	switch strings.ToUpper(unit) {
	case "B":
		return int(size), nil
	case "KB":
		return int(size * 1024), nil
	case "MB":
		return int(size * 1024 * 1024), nil
	case "GB":
		return int(size * 1024 * 1024 * 1024), nil
	default:
		return 0, fmt.Errorf("unsupported unit: %s, you can specify one of B, KB, MB and GB", unit)
	}
}
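To make the accepted forms concrete, a quick illustration of what the resolver returns, with outputs worked out from the conversion rules above (import path assumed as in the earlier sketch):

```go
package main

import (
	"fmt"

	kafka "github.com/Trendyol/kafka-konsumer/v2" // assumed import path
)

func main() {
	for _, in := range []any{20971520, "20971520", "20mb", "0,5 gb", nil} {
		v, err := kafka.ResolveUnionIntOrStringValue(in)
		fmt.Println(in, v, err)
	}
	// 20971520   -> 20971520  (plain ints pass through unchanged)
	// "20971520" -> 20971520  (numeric strings are parsed directly)
	// "20mb"     -> 20971520  (20 * 1024 * 1024)
	// "0,5 gb"   -> 536870912 (commas act as decimal points: 0.5 * 1024^3)
	// nil        -> 0         (interpreted as "no limit")
}
```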
85 changes: 85 additions & 0 deletions data_units_test.go
@@ -0,0 +1,85 @@
package kafka

import "testing"

func TestDcp_ResolveConnectionBufferSize(t *testing.T) {
	tests := []struct {
		input any
		name  string
		want  int
	}{
		{
			name:  "When_Client_Gives_Int_Value",
			input: 20971520,
			want:  20971520,
		},
		{
			name:  "When_Client_Gives_UInt_Value",
			input: uint(10971520),
			want:  10971520,
		},
		{
			name:  "When_Client_Gives_StringInt_Value",
			input: "15971520",
			want:  15971520,
		},
		{
			name:  "When_Client_Gives_KB_Value",
			input: "500kb",
			want:  500 * 1024,
		},
		{
			name:  "When_Client_Gives_MB_Value",
			input: "10mb",
			want:  10 * 1024 * 1024,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got, _ := ResolveUnionIntOrStringValue(tt.input); got != tt.want {
				t.Errorf("ResolveConnectionBufferSize() = %v, want %v", got, tt.want)
			}
		})
	}
}

func TestConvertToBytes(t *testing.T) {
	testCases := []struct {
		input    string
		expected int
		err      bool
	}{
		{"1kb", 1024, false},
		{"5mb", 5 * 1024 * 1024, false},
		{"5,5mb", 5.5 * 1024 * 1024, false},
		{"8.5mb", 8.5 * 1024 * 1024, false},
		{"10,25 mb", 10.25 * 1024 * 1024, false},
		{"10gb", 10 * 1024 * 1024 * 1024, false},
		{"1KB", 1024, false},
		{"5MB", 5 * 1024 * 1024, false},
		{"12 MB", 12 * 1024 * 1024, false},
		{"10GB", 10 * 1024 * 1024 * 1024, false},
		{"123", 0, true},
		{"15TB", 0, true},
		{"invalid", 0, true},
		{"", 0, true},
		{"123 KB", 123 * 1024, false},
		{"1 MB", 1 * 1024 * 1024, false},
	}

	for _, tc := range testCases {
		result, err := convertSizeUnitToByte(tc.input)

		if tc.err && err == nil {
			t.Errorf("Expected an error for input %s, but got none", tc.input)
		}

		if !tc.err && err != nil {
			t.Errorf("Unexpected error for input %s: %v", tc.input, err)
		}

		if result != tc.expected {
			t.Errorf("For input %s, expected %d bytes, but got %d", tc.input, tc.expected, result)
		}
	}
}
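One subtlety these cases pin down: a bare numeric string such as "123" is an error for convertSizeUnitToByte, because the last two characters are always treated as the unit and "23" matches nothing, yet the same string succeeds when routed through ResolveUnionIntOrStringValue, whose strconv.ParseInt fast path handles it before the unit parser is ever consulted.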