From 889404791d9125d2d17a7da963ae0fc2b268e2e7 Mon Sep 17 00:00:00 2001 From: Cyril Tovena Date: Tue, 14 Jan 2025 16:18:10 +0100 Subject: [PATCH] feat(kafka): Add cooperative active sticky balancer (#15706) --- pkg/kafka/partitionring/consumer/balancer.go | 159 ++++++ .../partitionring/consumer/balancer_test.go | 481 ++++++++++++++++++ pkg/kafka/partitionring/consumer/client.go | 105 ++++ .../partitionring/consumer/client_test.go | 321 ++++++++++++ 4 files changed, 1066 insertions(+) create mode 100644 pkg/kafka/partitionring/consumer/balancer.go create mode 100644 pkg/kafka/partitionring/consumer/balancer_test.go create mode 100644 pkg/kafka/partitionring/consumer/client.go create mode 100644 pkg/kafka/partitionring/consumer/client_test.go diff --git a/pkg/kafka/partitionring/consumer/balancer.go b/pkg/kafka/partitionring/consumer/balancer.go new file mode 100644 index 0000000000000..9ac361277484a --- /dev/null +++ b/pkg/kafka/partitionring/consumer/balancer.go @@ -0,0 +1,159 @@ +package consumer + +import ( + "sort" + + "github.com/grafana/dskit/ring" + "github.com/twmb/franz-go/pkg/kgo" + "github.com/twmb/franz-go/pkg/kmsg" +) + +type cooperativeActiveStickyBalancer struct { + kgo.GroupBalancer + partitionRing ring.PartitionRingReader +} + +// NewCooperativeActiveStickyBalancer creates a balancer that combines Kafka's cooperative sticky balancing +// with partition ring awareness. It works by: +// +// 1. Using the partition ring to determine which partitions are "active" (i.e. should be processed) +// 2. Filtering out inactive partitions from member assignments during rebalancing, but still assigning them +// 3. Applying cooperative sticky balancing only to the active partitions +// +// This ensures that: +// - Active partitions are balanced evenly across consumers using sticky assignment for optimal processing +// - Inactive partitions are still assigned and consumed in a round-robin fashion, but without sticky assignment +// - All partitions are monitored even if inactive, allowing quick activation when needed +// - Partition handoff happens cooperatively to avoid stop-the-world rebalances +// +// This balancer should be used with [NewGroupClient] which monitors the partition ring and triggers +// rebalancing when the set of active partitions changes. This ensures optimal partition distribution +// as the active partition set evolves. 
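// A minimal wiring sketch (illustrative only: the broker address, group, and topic
// names are placeholders, and a ring.PartitionRingReader value named partitionRing
// is assumed to be supplied by the caller). Production code would normally go
// through NewGroupClient in client.go, which also monitors the ring and triggers
// rebalances:
//
//	cl, err := kgo.NewClient(
//		kgo.SeedBrokers("localhost:9092"),
//		kgo.ConsumerGroup("example-group"),
//		kgo.ConsumeTopics("example-topic"),
//		kgo.Balancers(NewCooperativeActiveStickyBalancer(partitionRing)),
//	)
//	if err != nil {
//		// handle the error
//	}
//	defer cl.Close()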
+func NewCooperativeActiveStickyBalancer(partitionRing ring.PartitionRingReader) kgo.GroupBalancer { + return &cooperativeActiveStickyBalancer{ + GroupBalancer: kgo.CooperativeStickyBalancer(), + partitionRing: partitionRing, + } +} + +func (*cooperativeActiveStickyBalancer) ProtocolName() string { + return "cooperative-active-sticky" +} + +func (b *cooperativeActiveStickyBalancer) MemberBalancer(members []kmsg.JoinGroupResponseMember) (kgo.GroupMemberBalancer, map[string]struct{}, error) { + // Get active partitions from ring + activePartitions := make(map[int32]struct{}) + for _, id := range b.partitionRing.PartitionRing().PartitionIDs() { + activePartitions[id] = struct{}{} + } + + // Filter member metadata to only include active partitions + filteredMembers := make([]kmsg.JoinGroupResponseMember, len(members)) + for i, member := range members { + var meta kmsg.ConsumerMemberMetadata + err := meta.ReadFrom(member.ProtocolMetadata) + if err != nil { + continue + } + + // Filter owned partitions to only include active ones + filteredOwned := make([]kmsg.ConsumerMemberMetadataOwnedPartition, 0, len(meta.OwnedPartitions)) + for _, owned := range meta.OwnedPartitions { + filtered := kmsg.ConsumerMemberMetadataOwnedPartition{ + Topic: owned.Topic, + Partitions: make([]int32, 0, len(owned.Partitions)), + } + for _, p := range owned.Partitions { + if _, isActive := activePartitions[p]; isActive { + filtered.Partitions = append(filtered.Partitions, p) + } + } + if len(filtered.Partitions) > 0 { + filteredOwned = append(filteredOwned, filtered) + } + } + meta.OwnedPartitions = filteredOwned + + // Create filtered member + filteredMembers[i] = kmsg.JoinGroupResponseMember{ + MemberID: member.MemberID, + ProtocolMetadata: meta.AppendTo(nil), + } + } + + balancer, err := kgo.NewConsumerBalancer(b, filteredMembers) + return balancer, balancer.MemberTopics(), err +} + +// syncAssignments implements kgo.IntoSyncAssignment +type syncAssignments []kmsg.SyncGroupRequestGroupAssignment + +func (s syncAssignments) IntoSyncAssignment() []kmsg.SyncGroupRequestGroupAssignment { + return s +} + +func (b *cooperativeActiveStickyBalancer) Balance(balancer *kgo.ConsumerBalancer, topics map[string]int32) kgo.IntoSyncAssignment { + // Get active partition count + actives := b.partitionRing.PartitionRing().PartitionsCount() + + // First, let the sticky balancer handle active partitions + activeTopics := make(map[string]int32) + inactiveTopics := make(map[string]int32) + for topic, total := range topics { + activeTopics[topic] = int32(actives) + if total > int32(actives) { + inactiveTopics[topic] = total - int32(actives) + } + } + + // Get active partition assignment + assignment := b.GroupBalancer.(kgo.ConsumerBalancerBalance).Balance(balancer, activeTopics) + + plan := assignment.IntoSyncAssignment() + + // Get sorted list of members for deterministic round-robin + members := make([]string, 0, len(plan)) + for _, m := range plan { + members = append(members, m.MemberID) + } + sort.Strings(members) + + // Distribute inactive partitions round-robin + memberIdx := 0 + for topic, numInactive := range inactiveTopics { + for p := int32(actives); p < int32(actives)+numInactive; p++ { + // Find the member's assignment + for i, m := range plan { + if m.MemberID == members[memberIdx] { + var meta kmsg.ConsumerMemberAssignment + err := meta.ReadFrom(m.MemberAssignment) + if err != nil { + continue + } + + // Find or create topic assignment + found := false + for j, t := range meta.Topics { + if t.Topic == topic { + 
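							// The member already has an assignment entry for this topic:
							// append the inactive partition to its existing partition list.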
meta.Topics[j].Partitions = append(t.Partitions, p) + found = true + break + } + } + if !found { + meta.Topics = append(meta.Topics, kmsg.ConsumerMemberAssignmentTopic{ + Topic: topic, + Partitions: []int32{p}, + }) + } + + plan[i].MemberAssignment = meta.AppendTo(nil) + break + } + } + memberIdx = (memberIdx + 1) % len(members) + } + } + + return syncAssignments(plan) +} diff --git a/pkg/kafka/partitionring/consumer/balancer_test.go b/pkg/kafka/partitionring/consumer/balancer_test.go new file mode 100644 index 0000000000000..227e606af1a52 --- /dev/null +++ b/pkg/kafka/partitionring/consumer/balancer_test.go @@ -0,0 +1,481 @@ +package consumer + +import ( + "context" + "sort" + "testing" + "time" + + "github.com/grafana/dskit/ring" + "github.com/stretchr/testify/require" + "github.com/twmb/franz-go/pkg/kadm" + "github.com/twmb/franz-go/pkg/kfake" + "github.com/twmb/franz-go/pkg/kgo" + "github.com/twmb/franz-go/pkg/kmsg" +) + +// Helper types for testing +type memberUpdate struct { + memberID string + topics map[string][]int32 +} + +type mockPartitionRing struct { + partitionIDs []int32 +} + +// Mock implementation of PartitionRing interface +func (m *mockPartitionRing) PartitionIDs() []int32 { + return m.partitionIDs +} + +type mockPartitionRingReader struct { + ring *mockPartitionRing +} + +func (m *mockPartitionRingReader) PartitionRing() *ring.PartitionRing { + desc := ring.PartitionRingDesc{ + Partitions: make(map[int32]ring.PartitionDesc), + } + for _, id := range m.ring.partitionIDs { + desc.Partitions[id] = ring.PartitionDesc{ + Id: id, + State: ring.PartitionActive, + Tokens: []uint32{uint32(id)}, // Use partition ID as token for simplicity + } + } + return ring.NewPartitionRing(desc) +} + +func TestCooperativeActiveStickyBalancer(t *testing.T) { + type memberState struct { + id string + currentPartitions []int32 // nil means new member + } + + type memberResult struct { + id string + partitions []int32 + } + + type testCase struct { + name string + activePartitions []int32 + totalPartitions int32 + members []memberState + expected []memberResult // expected partition assignments per member + } + + tests := []testCase{ + { + name: "initial assignment with two members", + activePartitions: []int32{0, 1, 2}, + totalPartitions: 6, + members: []memberState{ + {id: "member-1"}, + {id: "member-2"}, + }, + expected: []memberResult{ + {id: "member-1", partitions: []int32{0, 3, 5}}, // 1 active (0), 2 inactive (3,5) + {id: "member-2", partitions: []int32{1, 2, 4}}, // 2 active (1,2), 1 inactive (4) + }, + }, + { + name: "rebalance when adding third member", + activePartitions: []int32{0, 1, 2}, + totalPartitions: 6, + members: []memberState{ + {id: "member-1", currentPartitions: []int32{0, 3, 5}}, + {id: "member-2", currentPartitions: []int32{1, 2, 4}}, + {id: "member-3"}, + }, + expected: []memberResult{ + {id: "member-1", partitions: []int32{0, 3}}, // keeps active 0, keeps inactive 3 + {id: "member-2", partitions: []int32{1, 4}}, // keeps active 1, keeps inactive 4 + {id: "member-3", partitions: []int32{2, 5}}, // gets active 2, gets inactive 5 + }, + }, + { + name: "complex rebalance with more partitions", + activePartitions: []int32{0, 1, 2, 3, 4}, + totalPartitions: 10, + members: []memberState{ + {id: "member-1", currentPartitions: []int32{0, 1, 5, 6}}, + {id: "member-2", currentPartitions: []int32{2, 3, 7, 8}}, + {id: "member-3", currentPartitions: []int32{4, 9}}, + {id: "member-4"}, + }, + expected: []memberResult{ + {id: "member-1", partitions: []int32{0, 5, 9}}, // keeps active 
0, keeps inactive 5, gets inactive 9 + {id: "member-2", partitions: []int32{2, 3, 6}}, // keeps active 2,3, gets inactive 6 + {id: "member-3", partitions: []int32{4, 7}}, // keeps active 4, gets inactive 7 + {id: "member-4", partitions: []int32{1, 8}}, // gets active 1,gets inactive 8 + }, + }, + { + name: "member leaves with many partitions", + activePartitions: []int32{0, 1, 2, 3, 4, 5}, + totalPartitions: 12, + members: []memberState{ + {id: "member-1", currentPartitions: []int32{0, 1, 6, 7}}, + {id: "member-2", currentPartitions: []int32{2, 3, 8, 9}}, + // member-3 left, had partitions: [4, 5, 10, 11] + }, + expected: []memberResult{ + {id: "member-1", partitions: []int32{0, 1, 4, 6, 8, 10}}, // keeps active 0,1, gets active 4, keeps inactive 6, gets inactive 8,10 + {id: "member-2", partitions: []int32{2, 3, 5, 7, 9, 11}}, // keeps active 2,3, gets active 5, gets inactive 7, keeps inactive 9, gets inactive 11 + }, + }, + { + name: "all members leave except one", + activePartitions: []int32{0, 1, 2, 3}, + totalPartitions: 8, + members: []memberState{ + {id: "member-1", currentPartitions: []int32{0, 4}}, + // member-2 left, had [1, 5] + // member-3 left, had [2, 6] + // member-4 left, had [3, 7] + }, + expected: []memberResult{ + {id: "member-1", partitions: []int32{0, 1, 2, 3, 4, 5, 6, 7}}, // gets all partitions + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup mock ring + mockRing := &mockPartitionRing{partitionIDs: tc.activePartitions} + mockReader := &mockPartitionRingReader{ring: mockRing} + balancer := NewCooperativeActiveStickyBalancer(mockReader) + + // First rebalance: members announce what they want to give up + members := make([]kmsg.JoinGroupResponseMember, len(tc.members)) + for i, m := range tc.members { + var currentAssignment map[string][]int32 + if m.currentPartitions != nil { + currentAssignment = map[string][]int32{"topic-1": m.currentPartitions} + } + members[i] = kmsg.JoinGroupResponseMember{ + MemberID: m.id, + ProtocolMetadata: createMemberMetadata(t, []string{"topic-1"}, currentAssignment), + } + } + + memberBalancer, _, err := balancer.MemberBalancer(members) + require.NoError(t, err) + assignment, err := memberBalancer.(kgo.GroupMemberBalancerOrError).BalanceOrError(map[string]int32{"topic-1": tc.totalPartitions}) + require.NoError(t, err) + plan := assignment.IntoSyncAssignment() + + // Get current assignments after first rebalance + currentAssignments := make(map[string][]int32) + for _, m := range plan { + partitions := extractPartitions(t, plan, m.MemberID, "topic-1") + if len(partitions) > 0 { + currentAssignments[m.MemberID] = partitions + } + } + + // Second rebalance: members can take new partitions + members = make([]kmsg.JoinGroupResponseMember, len(tc.members)) + for i, m := range tc.members { + currentPartitions := currentAssignments[m.id] + if currentPartitions == nil && m.currentPartitions != nil { + currentPartitions = m.currentPartitions + } + members[i] = kmsg.JoinGroupResponseMember{ + MemberID: m.id, + ProtocolMetadata: createMemberMetadata(t, []string{"topic-1"}, map[string][]int32{"topic-1": currentPartitions}), + } + } + + memberBalancer, _, err = balancer.MemberBalancer(members) + require.NoError(t, err) + assignment, err = memberBalancer.(kgo.GroupMemberBalancerOrError).BalanceOrError(map[string]int32{"topic-1": tc.totalPartitions}) + require.NoError(t, err) + plan = assignment.IntoSyncAssignment() + + // Verify final assignments for each member + for _, expected := range tc.expected { + actual 
:= extractPartitions(t, plan, expected.id, "topic-1") + sort.Slice(actual, func(i, j int) bool { return actual[i] < actual[j] }) + sort.Slice(expected.partitions, func(i, j int) bool { return expected.partitions[i] < expected.partitions[j] }) + + require.Equal(t, expected.partitions, actual, + "Member %s got wrong partition assignment.\nExpected: %v (active: %v)\nGot: %v (active: %v)", + expected.id, + expected.partitions, countActivePartitions(expected.partitions, tc.activePartitions), + actual, countActivePartitions(actual, tc.activePartitions)) + } + }) + } +} + +// Test helpers for consumer group testing +type testConsumerGroup struct { + t *testing.T + admin *kadm.Client + mockRing *mockPartitionRing + mockReader *mockPartitionRingReader + groupName string + clusterAddrs []string +} + +func newTestConsumerGroup(t *testing.T, numPartitions int) *testConsumerGroup { + // Create a fake cluster + cluster := kfake.MustCluster( + kfake.NumBrokers(2), + kfake.SeedTopics(int32(numPartitions), "test-topic"), + ) + t.Cleanup(func() { cluster.Close() }) + + addrs := cluster.ListenAddrs() + require.NotEmpty(t, addrs) + + // Create admin client + admClient, err := kgo.NewClient( + kgo.SeedBrokers(addrs...), + ) + require.NoError(t, err) + t.Cleanup(func() { admClient.Close() }) + + admin := kadm.NewClient(admClient) + t.Cleanup(func() { admin.Close() }) + + // Create mock ring with first 3 partitions active + mockRing := &mockPartitionRing{partitionIDs: []int32{0, 1, 2}} + mockReader := &mockPartitionRingReader{ring: mockRing} + + return &testConsumerGroup{ + t: t, + admin: admin, + mockRing: mockRing, + mockReader: mockReader, + groupName: "test-group", + clusterAddrs: addrs, + } +} + +func (g *testConsumerGroup) createConsumer(id string) *kgo.Client { + client, err := kgo.NewClient( + kgo.SeedBrokers(g.clusterAddrs...), + kgo.ConsumerGroup(g.groupName), + kgo.ConsumeTopics("test-topic"), + kgo.Balancers(NewCooperativeActiveStickyBalancer(g.mockReader)), + kgo.ClientID(id), + kgo.OnPartitionsAssigned(func(_ context.Context, _ *kgo.Client, m map[string][]int32) { + g.t.Logf("Assigned partitions 1: %v", m) + }), + kgo.OnPartitionsAssigned(func(_ context.Context, _ *kgo.Client, m map[string][]int32) { + g.t.Logf("Assigned partitions 2: %v", m) + }), + ) + require.NoError(g.t, err) + return client +} + +func (g *testConsumerGroup) getAssignments() map[string][]int32 { + g.t.Helper() + ctx := context.Background() + groups, err := g.admin.DescribeGroups(ctx, g.groupName) + require.NoError(g.t, err) + + require.Len(g.t, groups, 1) + group := groups[g.groupName] + + assignments := make(map[string][]int32) + for _, member := range group.Members { + // Extract base member ID (without the suffix) + baseMemberID := member.ClientID + + c, ok := member.Assigned.AsConsumer() + require.True(g.t, ok) + for _, topic := range c.Topics { + if topic.Topic == "test-topic" { + assignments[baseMemberID] = topic.Partitions + } + } + } + return assignments +} + +func (g *testConsumerGroup) waitForStableAssignments(expectedMembers int, timeout time.Duration) map[string][]int32 { + g.t.Helper() + deadline := time.Now().Add(timeout) + var lastAssignments map[string][]int32 + + for time.Now().Before(deadline) { + assignments := g.getAssignments() + if len(assignments) == expectedMembers { + // Check if assignments are stable + if lastAssignments != nil { + stable := true + for id, parts := range assignments { + lastParts, ok := lastAssignments[id] + if !ok || !equalPartitions(parts, lastParts) { + stable = false + break + } + 
} + if stable { + return assignments + } + } + lastAssignments = assignments + } + time.Sleep(100 * time.Millisecond) + } + g.t.Fatalf("Timeout waiting for stable assignments with %d members", expectedMembers) + return nil +} + +func TestCooperativeActiveStickyBalancerE2E(t *testing.T) { + group := newTestConsumerGroup(t, 6) + + // Create first consumer + consumer1 := group.createConsumer("member-1") + defer consumer1.Close() + + // Wait for initial assignment + assignments := group.waitForStableAssignments(1, 5*time.Second) + t.Log("Initial state:") + t.Logf("Assignments: %v", assignments) + require.NotEmpty(t, assignments["member-1"]) + + // Create second consumer + consumer2 := group.createConsumer("member-2") + defer consumer2.Close() + + // Wait for rebalance + assignments = group.waitForStableAssignments(2, 5*time.Second) + t.Log("After consumer2 joins:") + t.Logf("Assignments: %v", assignments) + require.NotEmpty(t, assignments["member-1"]) + require.NotEmpty(t, assignments["member-2"]) + + // Create third consumer + consumer3 := group.createConsumer("member-3") + defer consumer3.Close() + + // Wait for rebalance + assignments = group.waitForStableAssignments(3, 5*time.Second) + t.Log("After consumer3 joins:") + t.Logf("Assignments: %v", assignments) + require.NotEmpty(t, assignments["member-1"]) + require.NotEmpty(t, assignments["member-2"]) + require.NotEmpty(t, assignments["member-3"]) + + // Close consumer2 to simulate it leaving + consumer2.Close() + + // Wait for rebalance + assignments = group.waitForStableAssignments(2, 5*time.Second) + t.Log("After consumer2 leaves:") + t.Logf("Assignments: %v", assignments) + require.NotEmpty(t, assignments["member-1"]) + require.NotEmpty(t, assignments["member-3"]) + + // Verify active partitions are evenly distributed + verifyActivePartitions := func(assignments map[string][]int32) int { + activeCount := 0 + for _, partitions := range assignments { + for _, p := range partitions { + if p <= 2 { // partitions 0,1,2 are active + activeCount++ + } + } + } + return activeCount + } + + active1 := verifyActivePartitions(map[string][]int32{"test-topic": assignments["member-1"]}) + active3 := verifyActivePartitions(map[string][]int32{"test-topic": assignments["member-3"]}) + t.Logf("Active partitions - Consumer1: %d, Consumer3: %d", active1, active3) + require.True(t, abs(active1-active3) <= 1, "Active partitions should be evenly distributed") +} + +// Helper to check if two partition slices are equal +func equalPartitions(a, b []int32) bool { + if len(a) != len(b) { + return false + } + sort.Slice(a, func(i, j int) bool { return a[i] < a[j] }) + sort.Slice(b, func(i, j int) bool { return b[i] < b[j] }) + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +// Helper function to count active partitions in an assignment +func countActivePartitions(partitions []int32, activePartitions []int32) int { + count := 0 + for _, p := range partitions { + if contains(activePartitions, p) { + count++ + } + } + return count +} + +// Helper function to get absolute difference +func abs(x int) int { + if x < 0 { + return -x + } + return x +} + +// Helper function to check if a slice contains a value +func contains(slice []int32, val int32) bool { + for _, item := range slice { + if item == val { + return true + } + } + return false +} + +// Helper functions that create metadata for a member +func createMemberMetadata(t *testing.T, topics []string, currentAssignment map[string][]int32) []byte { + t.Helper() + meta := 
kmsg.NewConsumerMemberMetadata() + meta.Version = 3 + meta.Topics = topics + meta.Generation = 1 + + if currentAssignment != nil { + for topic, partitions := range currentAssignment { + owned := kmsg.NewConsumerMemberMetadataOwnedPartition() + owned.Topic = topic + owned.Partitions = partitions + meta.OwnedPartitions = append(meta.OwnedPartitions, owned) + } + sort.Slice(meta.OwnedPartitions, func(i, j int) bool { + return meta.OwnedPartitions[i].Topic < meta.OwnedPartitions[j].Topic + }) + } + + return meta.AppendTo(nil) +} + +// Helper function to extract partitions from a plan +func extractPartitions(t *testing.T, plan []kmsg.SyncGroupRequestGroupAssignment, memberID, topic string) []int32 { + t.Helper() + for _, assignment := range plan { + if assignment.MemberID == memberID { + var meta kmsg.ConsumerMemberAssignment + err := meta.ReadFrom(assignment.MemberAssignment) + require.NoError(t, err) + for _, topicAssignment := range meta.Topics { + if topicAssignment.Topic == topic { + return topicAssignment.Partitions + } + } + } + } + return nil +} diff --git a/pkg/kafka/partitionring/consumer/client.go b/pkg/kafka/partitionring/consumer/client.go new file mode 100644 index 0000000000000..2e218949f9094 --- /dev/null +++ b/pkg/kafka/partitionring/consumer/client.go @@ -0,0 +1,105 @@ +package consumer + +import ( + "sync" + "time" + + "github.com/go-kit/log" + "github.com/go-kit/log/level" + "github.com/grafana/dskit/ring" + "github.com/twmb/franz-go/pkg/kgo" + "github.com/twmb/franz-go/plugin/kprom" + + "github.com/grafana/loki/v3/pkg/kafka" + "github.com/grafana/loki/v3/pkg/kafka/client" +) + +type Client struct { + *kgo.Client + partitionRing ring.PartitionRingReader + logger log.Logger + stopCh chan struct{} + wg sync.WaitGroup +} + +// NewGroupClient creates a new Kafka consumer group client that participates in cooperative group consumption. +// It joins the specified consumer group and consumes messages from the configured Kafka topic. +// +// The client uses a cooperative-active-sticky balancing strategy which ensures active partitions are evenly +// distributed across group members while maintaining sticky assignments for optimal processing. Inactive partitions +// are still assigned and monitored but not processed, allowing quick activation when needed. Partition handoffs +// happen cooperatively to avoid stop-the-world rebalances. +// +// The client runs a background goroutine that monitors the partition ring for changes. When the set of active +// partitions changes (e.g. due to scaling or failures), it triggers a rebalance to ensure partitions are +// properly redistributed across the consumer group members. This maintains optimal processing as the active +// partition set evolves. +func NewGroupClient(kafkaCfg kafka.Config, partitionRing ring.PartitionRingReader, groupName string, metrics *kprom.Metrics, logger log.Logger, opts ...kgo.Opt) (*Client, error) { + defaultOpts := []kgo.Opt{ + kgo.ConsumerGroup(groupName), + kgo.ConsumeTopics(kafkaCfg.Topic), + kgo.Balancers(NewCooperativeActiveStickyBalancer(partitionRing)), + kgo.ConsumeResetOffset(kgo.NewOffset().AtStart()), + kgo.DisableAutoCommit(), + kgo.RebalanceTimeout(5 * time.Minute), + } + + // Combine remaining options with our defaults + allOpts := append(defaultOpts, opts...) + + client, err := client.NewReaderClient(kafkaCfg, metrics, logger, allOpts...) 
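	// Fail fast if the reader client cannot be constructed; the partition monitor
	// goroutine below is only started once the client exists.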
+ if err != nil { + return nil, err + } + + c := &Client{ + Client: client, + partitionRing: partitionRing, + stopCh: make(chan struct{}), + logger: logger, + } + + // Start the partition monitor goroutine + c.wg.Add(1) + go c.monitorPartitions() + + return c, nil +} + +func (c *Client) monitorPartitions() { + defer c.wg.Done() + + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + // Get initial partition count from the ring + lastPartitionCount := c.partitionRing.PartitionRing().PartitionsCount() + + for { + select { + case <-c.stopCh: + return + case <-ticker.C: + // Get current partition count from the ring + currentPartitionCount := c.partitionRing.PartitionRing().PartitionsCount() + if currentPartitionCount != lastPartitionCount { + level.Info(c.logger).Log( + "msg", "partition count changed, triggering rebalance", + "previous_count", lastPartitionCount, + "current_count", currentPartitionCount, + ) + // Trigger a rebalance to update partition assignments + // All consumers trigger the rebalance, but only the group leader will actually perform it + // For non-leader consumers, triggering the rebalance has no effect + c.ForceRebalance() + lastPartitionCount = currentPartitionCount + } + } + } +} + +func (c *Client) Close() { + close(c.stopCh) // Signal the monitor goroutine to stop + c.wg.Wait() // Wait for the monitor goroutine to exit + c.Client.Close() // Close the underlying client +} diff --git a/pkg/kafka/partitionring/consumer/client_test.go b/pkg/kafka/partitionring/consumer/client_test.go new file mode 100644 index 0000000000000..897d33a884a17 --- /dev/null +++ b/pkg/kafka/partitionring/consumer/client_test.go @@ -0,0 +1,321 @@ +package consumer + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/go-kit/log" + "github.com/stretchr/testify/require" + "github.com/twmb/franz-go/pkg/kfake" + "github.com/twmb/franz-go/pkg/kgo" + "github.com/twmb/franz-go/plugin/kprom" + + "github.com/grafana/loki/v3/pkg/kafka" +) + +func TestPartitionMonitorRebalancing(t *testing.T) { + // Create a fake cluster with initial partitions + const totalPartitions = 6 + cluster := kfake.MustCluster( + kfake.NumBrokers(2), + kfake.SeedTopics(totalPartitions, "test-topic"), + ) + defer cluster.Close() + + addrs := cluster.ListenAddrs() + require.NotEmpty(t, addrs) + + // Create a producer client + producer, err := kgo.NewClient( + kgo.SeedBrokers(addrs...), + ) + require.NoError(t, err) + defer producer.Close() + + // Create mock ring with initial active partitions + mockRing := &mockPartitionRing{partitionIDs: []int32{0, 1}} + mockReader := &mockPartitionRingReader{ring: mockRing} + + // Track processed records to verify continuity + type recordKey struct { + partition int32 + offset int64 + } + processedRecords := sync.Map{} + var processingWg sync.WaitGroup + + // Create two consumers using our Client wrapper + createConsumer := func(id string) *Client { + cfg := kafka.Config{ + Address: addrs[0], + Topic: "test-topic", + } + + // Track partition assignments for this consumer + var assignedPartitions sync.Map + var partitionsLock sync.Mutex + + client, err := NewGroupClient(cfg, mockReader, "test-group", kprom.NewMetrics("foo"), log.NewNopLogger(), + kgo.ClientID(id), + kgo.OnPartitionsAssigned(func(_ context.Context, _ *kgo.Client, assigned map[string][]int32) { + partitionsLock.Lock() + defer partitionsLock.Unlock() + t.Logf("%s assigned partitions: %v", id, assigned["test-topic"]) + for _, p := range assigned["test-topic"] { + assignedPartitions.Store(p, 
struct{}{}) + } + }), + kgo.OnPartitionsRevoked(func(_ context.Context, _ *kgo.Client, revoked map[string][]int32) { + partitionsLock.Lock() + defer partitionsLock.Unlock() + t.Logf("%s revoked partitions: %v", id, revoked["test-topic"]) + for _, p := range revoked["test-topic"] { + assignedPartitions.Delete(p) + } + // Wait for in-flight processing before revoking + t.Logf("%s waiting for in-flight processing before revoke...", id) + processingWg.Wait() + t.Logf("%s completed revoke", id) + }), + ) + require.NoError(t, err) + + // Start consuming in a goroutine + go func() { + for { + ctx := context.Background() + records := client.PollFetches(ctx) + if records == nil { + return // client closed + } + + if len(records.Records()) > 0 { + processingWg.Add(1) + go func(fetches kgo.Fetches) { + defer processingWg.Done() + + // Verify we only got records for partitions we own + for _, record := range fetches.Records() { + _, ok := assignedPartitions.Load(record.Partition) + require.True(t, ok, "%s received record for unassigned partition %d", id, record.Partition) + + // Track this record + key := recordKey{record.Partition, record.Offset} + if prev, loaded := processedRecords.LoadOrStore(key, id); loaded { + t.Errorf("Record at partition %d offset %d processed twice! First by %v, then by %v", + key.partition, key.offset, prev, id) + } + } + + // Commit the records + if err := client.CommitRecords(context.Background(), fetches.Records()...); err != nil { + t.Logf("%s error committing: %v", id, err) + } + }(records) + } + } + }() + + return client + } + + // Create two consumers + consumer1 := createConsumer("consumer1") + defer consumer1.Close() + consumer2 := createConsumer("consumer2") + defer consumer2.Close() + + // Start producing records to all partitions + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + i := 0 + for ctx.Err() == nil { + for partition := 0; partition < totalPartitions; partition++ { + err := producer.ProduceSync(ctx, &kgo.Record{ + Topic: "test-topic", + Partition: int32(partition), + Key: []byte(fmt.Sprintf("key-%d", i)), + Value: []byte(fmt.Sprintf("value-%d", i)), + }).FirstErr() + if err != nil { + t.Logf("Error producing to partition %d: %v", partition, err) + return + } + } + i++ + time.Sleep(50 * time.Millisecond) + } + }() + + // Let the initial setup stabilize + time.Sleep(2 * time.Second) + + // Change the active partitions in the ring + t.Log("Changing active partitions from [0,1] to [0,1,2]") + mockRing.partitionIDs = []int32{0, 1, 2} + + // Wait for rebalancing to occur and stabilize + time.Sleep(7 * time.Second) + + // Change active partitions again + t.Log("Changing active partitions from [0,1,2] to [0,1,2,3]") + mockRing.partitionIDs = []int32{0, 1, 2, 3} + + // Wait for final rebalancing + time.Sleep(7 * time.Second) + + // Stop producing + cancel() + + // Wait for any in-flight processing to complete + processingWg.Wait() + + // Verify no duplicates were processed and count records per partition + partitionCounts := make(map[int32]int) + partitionConsumers := make(map[int32]map[string]struct{}) + processedRecords.Range(func(key, value interface{}) bool { + k := key.(recordKey) + partitionCounts[k.partition]++ + if _, ok := partitionConsumers[k.partition]; !ok { + partitionConsumers[k.partition] = make(map[string]struct{}) + } + partitionConsumers[k.partition][value.(string)] = struct{}{} + return true + }) + + // Log partition processing stats + for partition, count := range partitionCounts { + 
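		// Every partition, active or inactive, should have been consumed at least once:
		// inactive partitions are still assigned round-robin, just without stickiness.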
t.Logf("Partition %d: processed %d records", partition, count) + require.Greater(t, count, 0, "Expected records from partition %d", partition) + } + + for partition, consumers := range partitionConsumers { + t.Logf("Partition %d: processed by %v", partition, consumers) + } +} + +func TestPartitionContinuityDuringRebalance(t *testing.T) { + // Create a fake cluster with initial partitions + const totalPartitions = 4 + cluster := kfake.MustCluster( + kfake.NumBrokers(2), + kfake.SeedTopics(totalPartitions, "test-topic"), + ) + defer cluster.Close() + + addrs := cluster.ListenAddrs() + require.NotEmpty(t, addrs) + + // Create mock ring with initial active partitions + mockRing := &mockPartitionRing{partitionIDs: []int32{0, 1}} + mockReader := &mockPartitionRingReader{ring: mockRing} + + // Track offsets for partition 0 to verify continuous reading + var lastOffset int64 + var offsetMu sync.Mutex + + createConsumer := func(id string) *Client { + cfg := kafka.Config{ + Address: addrs[0], + Topic: "test-topic", + } + + client, err := NewGroupClient(cfg, mockReader, "test-group", kprom.NewMetrics("foo"), log.NewNopLogger(), + kgo.ClientID(id), + kgo.OnPartitionsAssigned(func(_ context.Context, _ *kgo.Client, assigned map[string][]int32) { + t.Logf("%s assigned partitions: %v", id, assigned["test-topic"]) + }), + kgo.OnPartitionsRevoked(func(_ context.Context, _ *kgo.Client, revoked map[string][]int32) { + t.Logf("%s revoked partitions: %v", id, revoked["test-topic"]) + }), + ) + require.NoError(t, err) + + go func() { + for { + ctx := context.Background() + records := client.PollFetches(ctx) + if records == nil { + return // client closed + } + + // Only verify partition 0's offset continuity + for _, record := range records.Records() { + if record.Partition == 0 { + offsetMu.Lock() + if lastOffset > 0 { + require.Equal(t, lastOffset+1, record.Offset, + "Gap detected in partition 0: expected offset %d, got %d", + lastOffset+1, record.Offset) + } + lastOffset = record.Offset + offsetMu.Unlock() + t.Logf("%s read offset %d from partition 0", id, record.Offset) + } + } + } + }() + + return client + } + + // Create producer + producer, err := kgo.NewClient(kgo.SeedBrokers(addrs...)) + require.NoError(t, err) + defer producer.Close() + + // Start with one consumer + consumer1 := createConsumer("consumer1") + defer consumer1.Close() + + // Start producing records + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + i := 0 + for ctx.Err() == nil { + for partition := 0; partition < totalPartitions; partition++ { + err := producer.ProduceSync(ctx, &kgo.Record{ + Topic: "test-topic", + Partition: int32(partition), + Key: []byte(fmt.Sprintf("key-%d", i)), + Value: []byte(fmt.Sprintf("value-%d", i)), + }).FirstErr() + if err != nil { + t.Logf("Error producing to partition %d: %v", partition, err) + return + } + } + i++ + time.Sleep(50 * time.Millisecond) + } + }() + + // Let initial setup stabilize and verify consumer1 is reading + time.Sleep(2 * time.Second) + require.Greater(t, lastOffset, int64(0), "consumer1 should have read some records from partition 0") + initialOffset := lastOffset + + // Add second consumer and change active partitions + t.Log("Adding consumer2 and changing active partitions from [0,1] to [0,1,2]") + consumer2 := createConsumer("consumer2") + defer consumer2.Close() + mockRing.partitionIDs = []int32{0, 1, 2} + + // Let it run for a while + time.Sleep(5 * time.Second) + + // Verify partition 0 continued reading without reset + 
require.Greater(t, lastOffset, initialOffset, + "partition 0 should have continued reading from last offset (%d)", initialOffset) + t.Logf("Partition 0 read from offset %d to %d during rebalance", initialOffset, lastOffset) + + // Stop everything + cancel() +}
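For reviewers who want to exercise the new client outside the test suite, a minimal consumption sketch could look like the following. This is illustrative only and not part of the PR: the package name, broker address, topic, group name, metrics namespace, and the runConsumer helper are placeholders, and the ring.PartitionRingReader is assumed to be supplied by the caller (for example, from the ingester partition ring).

package example

import (
	"context"

	"github.com/go-kit/log"
	"github.com/grafana/dskit/ring"
	"github.com/twmb/franz-go/pkg/kgo"
	"github.com/twmb/franz-go/plugin/kprom"

	"github.com/grafana/loki/v3/pkg/kafka"
	"github.com/grafana/loki/v3/pkg/kafka/partitionring/consumer"
)

// runConsumer wires a partition-ring-aware group client and consumes records
// until the context is cancelled. All concrete names here are placeholders.
func runConsumer(ctx context.Context, partitionRing ring.PartitionRingReader, logger log.Logger) error {
	cfg := kafka.Config{
		Address: "localhost:9092", // placeholder broker address
		Topic:   "example-topic",  // placeholder topic; a real deployment fills in the rest of kafka.Config
	}

	client, err := consumer.NewGroupClient(cfg, partitionRing, "example-group",
		kprom.NewMetrics("example"), logger)
	if err != nil {
		return err
	}
	defer client.Close() // stops the partition monitor and closes the underlying kgo client

	for {
		fetches := client.PollFetches(ctx)
		if fetches.IsClientClosed() || ctx.Err() != nil {
			return ctx.Err()
		}
		fetches.EachRecord(func(r *kgo.Record) {
			// Process the record. Both active and inactive partitions are delivered;
			// the partition ring determines which ones are considered active.
			_ = r
		})
		// NewGroupClient disables auto-commit, so commit explicitly after processing.
		if err := client.CommitUncommittedOffsets(ctx); err != nil {
			return err
		}
	}
}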