From a5ea567f1b7d20b5d42ce97ff54944cb7b8a0e5b Mon Sep 17 00:00:00 2001 From: musa-asad Date: Thu, 6 Mar 2025 00:13:21 -0500 Subject: [PATCH] initial --- extension/k8smetadata/README.md | 0 extension/k8smetadata/config.go | 9 + extension/k8smetadata/extension.go | 39 +++ extension/k8smetadata/factory.go | 40 +++ .../k8sclient/endpointslicewatcher.go | 293 +++++++++++++++++ .../k8sclient/endpointslicewatcher_test.go | 296 ++++++++++++++++++ .../k8sCommon/k8sclient/kubernetes_utils.go | 231 ++++++++++++++ .../k8sclient/kubernetes_utils_test.go | 258 +++++++++++++++ 8 files changed, 1166 insertions(+) create mode 100644 extension/k8smetadata/README.md create mode 100644 extension/k8smetadata/config.go create mode 100644 extension/k8smetadata/extension.go create mode 100644 extension/k8smetadata/factory.go create mode 100644 internal/k8sCommon/k8sclient/endpointslicewatcher.go create mode 100644 internal/k8sCommon/k8sclient/endpointslicewatcher_test.go create mode 100644 internal/k8sCommon/k8sclient/kubernetes_utils.go create mode 100644 internal/k8sCommon/k8sclient/kubernetes_utils_test.go diff --git a/extension/k8smetadata/README.md b/extension/k8smetadata/README.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/extension/k8smetadata/config.go b/extension/k8smetadata/config.go new file mode 100644 index 0000000000..b961b1b141 --- /dev/null +++ b/extension/k8smetadata/config.go @@ -0,0 +1,9 @@ +package k8smetadata + +import ( + "go.opentelemetry.io/collector/component" +) + +type Config struct {} + +var _ component.Config = (*Config)(nil) diff --git a/extension/k8smetadata/extension.go b/extension/k8smetadata/extension.go new file mode 100644 index 0000000000..defd629c29 --- /dev/null +++ b/extension/k8smetadata/extension.go @@ -0,0 +1,39 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: MIT + +package k8smetadata + +import ( + "context" + "github.com/aws/amazon-cloudwatch-agent/internal/k8sCommon/k8sclient" + "go.opentelemetry.io/collector/component" + "go.opentelemetry.io/collector/extension" + "go.uber.org/zap" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/tools/cache" + "sync" +) + +type KubernetesMetadata struct { + logger *zap.Logger + config *Config + + mu sync.Mutex + + clientset kubernetes.Interface + sharedInformerFactory cache.SharedInformer + + ipToPodMetadata *sync.Map + + endpointSliceWatcher *k8sclient.EndpointSliceWatcher +} + +var _ extension.Extension = (*KubernetesMetadata)(nil) + +func (e *KubernetesMetadata) Start(ctx context.Context, host component.Host) error { + return nil +} + +func (e *KubernetesMetadata) Shutdown(_ context.Context) error { + return nil +} diff --git a/extension/k8smetadata/factory.go b/extension/k8smetadata/factory.go new file mode 100644 index 0000000000..9fdef377b1 --- /dev/null +++ b/extension/k8smetadata/factory.go @@ -0,0 +1,40 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: MIT + +package k8smetadata + +import ( + "context" + "go.opentelemetry.io/collector/component" + "go.opentelemetry.io/collector/extension" + "sync" +) + +var ( + TypeStr, _ = component.NewType("k8smetadata") + kubernetesMetadata *KubernetesMetadata + mutex sync.RWMutex +) + +func NewFactory() extension.Factory { + return extension.NewFactory( + TypeStr, + createDefaultConfig, + createExtension, + component.StabilityLevelAlpha, + ) +} + +func createDefaultConfig() component.Config { + return &Config{} +} + +func createExtension(_ context.Context, settings extension.Settings, cfg component.Config) (extension.Extension, error) { + mutex.Lock() + defer mutex.Unlock() + kubernetesMetadata = &KubernetesMetadata{ + logger: settings.Logger, + config: cfg.(*Config), + } + return kubernetesMetadata, nil +} diff --git a/internal/k8sCommon/k8sclient/endpointslicewatcher.go b/internal/k8sCommon/k8sclient/endpointslicewatcher.go new file mode 100644 index 0000000000..b54cd4b376 --- /dev/null +++ b/internal/k8sCommon/k8sclient/endpointslicewatcher.go @@ -0,0 +1,293 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: MIT + +package k8sclient + +import ( + "fmt" + "sync" + + "go.uber.org/zap" + discv1 "k8s.io/api/discovery/v1" + "k8s.io/client-go/informers" + "k8s.io/client-go/tools/cache" +) + +// EndpointSliceWatcher watches EndpointSlices and builds: +// 1. ip/ip:port -> "workload@namespace" +// 2. service@namespace -> "workload@namespace" +type EndpointSliceWatcher struct { + logger *zap.Logger + informer cache.SharedIndexInformer + ipToWorkload *sync.Map // key: "ip" or "ip:port", val: "workload@ns" + serviceToWorkload *sync.Map // key: "service@namespace", val: "workload@ns" + + // For bookkeeping, so we can remove old mappings upon EndpointSlice deletion + sliceToKeysMap sync.Map // map[sliceUID string] -> []string of keys we inserted, which are "ip", "ip:port", or "service@namespace" + deleter Deleter +} + +// kvPair holds one mapping from key -> value. The isService flag +// indicates whether this key is for a Service or for an IP/IP:port. +type kvPair struct { + key string // key: "ip" or "ip:port" or "service@namespace" + value string // value: "workload@namespace" + isService bool // true if key = "service@namespace" +} + +// newEndpointSliceWatcher creates an EndpointSlice watcher for the new approach (when USE_LIST_POD=false). +func newEndpointSliceWatcher( + logger *zap.Logger, + factory informers.SharedInformerFactory, + deleter Deleter, +) *EndpointSliceWatcher { + + esInformer := factory.Discovery().V1().EndpointSlices().Informer() + err := esInformer.SetTransform(minimizeEndpointSlice) + if err != nil { + logger.Error("failed to minimize Service objects", zap.Error(err)) + } + + return &EndpointSliceWatcher{ + logger: logger, + informer: esInformer, + ipToWorkload: &sync.Map{}, + serviceToWorkload: &sync.Map{}, + deleter: deleter, + } +} + +// run starts the endpointSliceWatcher. 
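+// It registers add/update/delete handlers that keep ipToWorkload and serviceToWorkload
+// in sync with the cluster, then starts the underlying informer in a goroutine; callers
+// should use waitForCacheSync before reading the maps.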
+func (w *EndpointSliceWatcher) Run(stopCh chan struct{}) { + w.informer.AddEventHandler(cache.ResourceEventHandlerFuncs{ + AddFunc: func(obj interface{}) { + w.handleSliceAdd(obj) + }, + UpdateFunc: func(oldObj, newObj interface{}) { + w.handleSliceUpdate(newObj, oldObj) + }, + DeleteFunc: func(obj interface{}) { + w.handleSliceDelete(obj) + }, + }) + go w.informer.Run(stopCh) +} + +func (w *EndpointSliceWatcher) waitForCacheSync(stopCh chan struct{}) { + if !cache.WaitForNamedCacheSync("endpointSliceWatcher", stopCh, w.informer.HasSynced) { + w.logger.Fatal("timed out waiting for endpointSliceWatcher cache to sync") + } + w.logger.Info("endpointSliceWatcher: Cache synced") +} + +// extractEndpointSliceKeyValuePairs computes the relevant mappings from an EndpointSlice. +// +// It returns a list of kvPair: +// - All IP and IP:port keys (isService=false) -> "workload@ns" +// - The Service name key (isService=true) -> first "workload@ns" found +// +// This function does NOT modify ipToWorkload or serviceToWorkload. It's purely for computing +// the pairs, so it can be reused by both add and update methods. +func (w *EndpointSliceWatcher) extractEndpointSliceKeyValuePairs(slice *discv1.EndpointSlice) []kvPair { + var pairs []kvPair + + isFirstPod := true + svcName := slice.Labels["kubernetes.io/service-name"] + + for _, endpoint := range slice.Endpoints { + if endpoint.TargetRef != nil { + if endpoint.TargetRef.Kind != "Pod" { + continue + } + + podName := endpoint.TargetRef.Name + ns := endpoint.TargetRef.Namespace + + derivedWorkload := inferWorkloadName(podName, svcName) + if derivedWorkload == "" { + w.logger.Warn("failed to infer workload name from Pod name", zap.String("podName", podName)) + continue + } + fullWl := derivedWorkload + "@" + ns + + // Build IP and IP:port pairs + for _, addr := range endpoint.Addresses { + // "ip" -> "workload@namespace" + pairs = append(pairs, kvPair{ + key: addr, + value: fullWl, + isService: false, + }) + + // "ip:port" -> "workload@namespace" for each port + for _, portDef := range slice.Ports { + if portDef.Port != nil { + ipPort := fmt.Sprintf("%s:%d", addr, *portDef.Port) + pairs = append(pairs, kvPair{ + key: ipPort, + value: fullWl, + isService: false, + }) + } + } + } + + // Build service name -> "workload@namespace" pair from the first pod + if isFirstPod { + isFirstPod = false + if svcName != "" { + pairs = append(pairs, kvPair{ + key: svcName + "@" + ns, + value: fullWl, + isService: true, + }) + } + } + } + + } + + return pairs +} + +// handleSliceAdd handles a new EndpointSlice that wasn't seen before. +// It computes all keys and directly stores them. Then it records those keys +// in sliceToKeysMap so that we can remove them later upon deletion. +func (w *EndpointSliceWatcher) handleSliceAdd(obj interface{}) { + newSlice := obj.(*discv1.EndpointSlice) + sliceUID := string(newSlice.UID) + + // Compute all key-value pairs for this new slice + pairs := w.extractEndpointSliceKeyValuePairs(newSlice) + + // Insert them into our ipToWorkload / serviceToWorkload, and track the keys. + keys := make([]string, 0, len(pairs)) + for _, kv := range pairs { + if kv.isService { + w.serviceToWorkload.Store(kv.key, kv.value) + } else { + w.ipToWorkload.Store(kv.key, kv.value) + } + keys = append(keys, kv.key) + } + + // Save these keys so we can remove them on delete + w.sliceToKeysMap.Store(sliceUID, keys) +} + +// handleSliceUpdate handles an update from oldSlice -> newSlice. 
+// Instead of blindly removing all old keys and adding new ones, it diffs them: +// - remove only keys that no longer exist, +// - add only new keys that didn't exist before, +// - keep those that haven't changed. +func (w *EndpointSliceWatcher) handleSliceUpdate(oldObj, newObj interface{}) { + oldSlice := oldObj.(*discv1.EndpointSlice) + newSlice := newObj.(*discv1.EndpointSlice) + + oldUID := string(oldSlice.UID) + newUID := string(newSlice.UID) + + // 1) Fetch old keys from sliceToKeysMap (if present). + var oldKeys []string + if val, ok := w.sliceToKeysMap.Load(oldUID); ok { + oldKeys = val.([]string) + } + + // 2) Compute fresh pairs (and thus keys) from the new slice. + newPairs := w.extractEndpointSliceKeyValuePairs(newSlice) + var newKeys []string + for _, kv := range newPairs { + newKeys = append(newKeys, kv.key) + } + + // Convert oldKeys/newKeys to sets for easy diff + oldKeysSet := make(map[string]struct{}, len(oldKeys)) + for _, k := range oldKeys { + oldKeysSet[k] = struct{}{} + } + newKeysSet := make(map[string]struct{}, len(newKeys)) + for _, k := range newKeys { + newKeysSet[k] = struct{}{} + } + + // 3) For each key in oldKeys that doesn't exist in newKeys, remove it + for k := range oldKeysSet { + if _, stillPresent := newKeysSet[k]; !stillPresent { + w.deleter.DeleteWithDelay(w.ipToWorkload, k) + w.deleter.DeleteWithDelay(w.serviceToWorkload, k) + } + } + + // 4) For each key in newKeys that wasn't in oldKeys, we need to store it + // in the appropriate sync.Map. We'll look up the value from newPairs. + for _, kv := range newPairs { + if _, alreadyHad := oldKeysSet[kv.key]; !alreadyHad { + if kv.isService { + w.serviceToWorkload.Store(kv.key, kv.value) + } else { + w.ipToWorkload.Store(kv.key, kv.value) + } + } + } + + // 5) Update sliceToKeysMap for the new slice UID + // (Often the UID doesn't change across updates, but we'll handle it properly.) + w.sliceToKeysMap.Delete(oldUID) + w.sliceToKeysMap.Store(newUID, newKeys) +} + +// handleSliceDelete removes any IP->workload or service->workload keys that were created by this slice. +func (w *EndpointSliceWatcher) handleSliceDelete(obj interface{}) { + slice := obj.(*discv1.EndpointSlice) + w.removeSliceKeys(slice) +} + +func (w *EndpointSliceWatcher) removeSliceKeys(slice *discv1.EndpointSlice) { + sliceUID := string(slice.UID) + val, ok := w.sliceToKeysMap.Load(sliceUID) + if !ok { + return + } + + keys := val.([]string) + for _, k := range keys { + w.deleter.DeleteWithDelay(w.ipToWorkload, k) + w.deleter.DeleteWithDelay(w.serviceToWorkload, k) + } + w.sliceToKeysMap.Delete(sliceUID) +} + +// minimizeEndpointSlice removes fields that are not required by our mapping logic, +// retaining only the minimal set of fields needed (ObjectMeta.Name, Namespace, UID, Labels, +// Endpoints (with their Addresses and TargetRef) and Ports). +func minimizeEndpointSlice(obj interface{}) (interface{}, error) { + eps, ok := obj.(*discv1.EndpointSlice) + if !ok { + return obj, fmt.Errorf("object is not an EndpointSlice") + } + + // Minimize metadata: we only really need Name, Namespace, UID and Labels. + eps.Annotations = nil + eps.ManagedFields = nil + eps.Finalizers = nil + + // The watcher only uses: + // - eps.Labels["kubernetes.io/service-name"] + // - eps.Namespace (from metadata) + // - eps.UID (from metadata) + // - eps.Endpoints: for each endpoint, its Addresses and TargetRef. + // - eps.Ports: each port's Port (and optionally Name/Protocol) + // + // For each endpoint, clear fields that we don’t use. 
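+	// (The transform registered via SetTransform runs before the object is stored in the
+	// informer cache, so the fields dropped here are never retained in memory.)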
+ for i := range eps.Endpoints { + // We only need Addresses and TargetRef. Hostname, NodeName, and Zone are not used. + eps.Endpoints[i].Hostname = nil + eps.Endpoints[i].NodeName = nil + eps.Endpoints[i].Zone = nil + eps.Endpoints[i].DeprecatedTopology = nil + eps.Endpoints[i].Hints = nil + } + + // No transformation is needed for eps.Ports because we use them directly. + return eps, nil +} diff --git a/internal/k8sCommon/k8sclient/endpointslicewatcher_test.go b/internal/k8sCommon/k8sclient/endpointslicewatcher_test.go new file mode 100644 index 0000000000..2a51289f84 --- /dev/null +++ b/internal/k8sCommon/k8sclient/endpointslicewatcher_test.go @@ -0,0 +1,296 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: MIT + +package k8sclient + +import ( + "fmt" + "reflect" + "sort" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + v1 "k8s.io/api/core/v1" + discv1 "k8s.io/api/discovery/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" +) + +func newEndpointSliceWatcherForTest() *EndpointSliceWatcher { + return &EndpointSliceWatcher{ + logger: zap.NewNop(), + ipToWorkload: &sync.Map{}, + serviceToWorkload: &sync.Map{}, + deleter: mockDeleter, + } +} + +// createTestEndpointSlice is a helper to build a minimal EndpointSlice. +// The slice will have one Endpoint (with its TargetRef) and a list of Ports. +// svcName is stored in the Labels (key "kubernetes.io/service-name") if non-empty. +func createTestEndpointSlice(uid, namespace, svcName, podName string, addresses []string, portNumbers []int32) *discv1.EndpointSlice { + // Build the port list. + var ports []discv1.EndpointPort + for i, p := range portNumbers { + portVal := p // need a pointer + name := fmt.Sprintf("port-%d", i) + protocol := v1.ProtocolTCP + ports = append(ports, discv1.EndpointPort{ + Name: &name, + Protocol: &protocol, + Port: &portVal, + }) + } + + // Build a single endpoint with the given addresses and a TargetRef. + endpoint := discv1.Endpoint{ + Addresses: addresses, + TargetRef: &v1.ObjectReference{ + Kind: "Pod", + Name: podName, + Namespace: namespace, + }, + } + + labels := map[string]string{} + if svcName != "" { + labels["kubernetes.io/service-name"] = svcName + } + + return &discv1.EndpointSlice{ + ObjectMeta: metav1.ObjectMeta{ + UID: types.UID(uid), + Namespace: namespace, + Labels: labels, + }, + Endpoints: []discv1.Endpoint{endpoint}, + Ports: ports, + } +} + +// --- Tests --- + +// TestEndpointSliceAddition verifies that when a new EndpointSlice is added, +// the appropriate keys are inserted into the maps. +func TestEndpointSliceAddition(t *testing.T) { + watcher := newEndpointSliceWatcherForTest() + + // Create a test EndpointSlice: + // - UID: "uid-1", Namespace: "testns" + // - Labels: "kubernetes.io/service-name" = "mysvc" + // - One Endpoint with TargetRef.Kind "Pod", Name "workload-69dww", Namespace "testns" + // - Endpoint.Addresses: ["1.2.3.4"] + // - One Port with value 80. + slice := createTestEndpointSlice("uid-1", "testns", "mysvc", "workload-69dww", []string{"1.2.3.4"}, []int32{80}) + + // Call the add handler. 
+ watcher.handleSliceAdd(slice) + + // The dummy inferWorkloadName returns "workload", so full workload becomes "workload@testns" + expectedVal := "workload@testns" + + // We expect the following keys: + // - For the endpoint: "1.2.3.4" and "1.2.3.4:80" + // - From the service label: "mysvc@testns" + var expectedIPKeys = []string{"1.2.3.4", "1.2.3.4:80"} + var expectedSvcKeys = []string{"mysvc@testns"} + + // Verify ipToWorkload. + for _, key := range expectedIPKeys { + val, ok := watcher.ipToWorkload.Load(key) + assert.True(t, ok, "expected ipToWorkload key %s", key) + assert.Equal(t, expectedVal, val, "ipToWorkload[%s] mismatch", key) + } + + // Verify serviceToWorkload. + for _, key := range expectedSvcKeys { + val, ok := watcher.serviceToWorkload.Load(key) + assert.True(t, ok, "expected serviceToWorkload key %s", key) + assert.Equal(t, expectedVal, val, "serviceToWorkload[%s] mismatch", key) + } + + // Verify that sliceToKeysMap recorded all keys. + val, ok := watcher.sliceToKeysMap.Load(string(slice.UID)) + assert.True(t, ok, "expected sliceToKeysMap to contain UID %s", slice.UID) + keysIface := val.([]string) + // Sort for comparison. + sort.Strings(keysIface) + allExpected := append(expectedIPKeys, expectedSvcKeys...) + sort.Strings(allExpected) + assert.Equal(t, allExpected, keysIface, "sliceToKeysMap keys mismatch") +} + +// TestEndpointSliceDeletion verifies that when an EndpointSlice is deleted, +// all keys that were added are removed. +func TestEndpointSliceDeletion(t *testing.T) { + watcher := newEndpointSliceWatcherForTest() + + // Create a test EndpointSlice (same as addition test). + slice := createTestEndpointSlice("uid-1", "testns", "mysvc", "workload-76977669dc-lwx64", []string{"1.2.3.4"}, []int32{80}) + watcher.handleSliceAdd(slice) + + // Now call deletion. + watcher.handleSliceDelete(slice) + + // Verify that the keys are removed from ipToWorkload. + removedKeys := []string{"1.2.3.4", "1.2.3.4:80", "mysvc@testns"} + for _, key := range removedKeys { + _, ok := watcher.ipToWorkload.Load(key) + _, okSvc := watcher.serviceToWorkload.Load(key) + assert.False(t, ok, "expected ipToWorkload key %s to be deleted", key) + assert.False(t, okSvc, "expected serviceToWorkload key %s to be deleted", key) + } + + // Also verify that sliceToKeysMap no longer contains an entry. + _, ok := watcher.sliceToKeysMap.Load(string(slice.UID)) + assert.False(t, ok, "expected sliceToKeysMap entry for UID %s to be deleted", slice.UID) +} + +// TestEndpointSliceUpdate verifies that on updates, keys are added and/or removed as appropriate. +func TestEndpointSliceUpdate(t *testing.T) { + // --- Subtest: Complete change (no overlap) --- + t.Run("complete change", func(t *testing.T) { + watcher := newEndpointSliceWatcherForTest() + + // Old slice: + // UID "uid-2", Namespace "testns", svc label "mysvc", + // One endpoint with TargetRef Name "workload-75d9d5968d-fx8px", Addresses ["1.2.3.4"], Port 80. + oldSlice := createTestEndpointSlice("uid-2", "testns", "mysvc", "workload-75d9d5968d-fx8px", []string{"1.2.3.4"}, []int32{80}) + watcher.handleSliceAdd(oldSlice) + + // New slice: same UID, but svc label changed to "othersvc" + // and a different endpoint: TargetRef Name "workload-6d9b7f8597-wbvxn", Addresses ["1.2.3.5"], Port 443. + newSlice := createTestEndpointSlice("uid-2", "testns", "othersvc", "workload-6d9b7f8597-wbvxn", []string{"1.2.3.5"}, []int32{443}) + + // Call update handler. 
+ watcher.handleSliceUpdate(oldSlice, newSlice) + + expectedVal := "workload@testns" + + // Old keys that should be removed: + // "1.2.3.4" and "1.2.3.4:80" and service key "mysvc@testns" + removedKeys := []string{"1.2.3.4", "1.2.3.4:80", "mysvc@testns"} + for _, key := range removedKeys { + _, ok := watcher.ipToWorkload.Load(key) + _, okSvc := watcher.serviceToWorkload.Load(key) + assert.False(t, ok, "expected ipToWorkload key %s to be removed", key) + assert.False(t, okSvc, "expected serviceToWorkload key %s to be removed", key) + } + + // New keys that should be added: + // "1.2.3.5", "1.2.3.5:443", and service key "othersvc@testns" + addedKeys := []string{"1.2.3.5", "1.2.3.5:443", "othersvc@testns"} + for _, key := range addedKeys { + var val interface{} + var ok bool + // For service key, check serviceToWorkload; for others, check ipToWorkload. + if key == "othersvc@testns" { + val, ok = watcher.serviceToWorkload.Load(key) + } else { + val, ok = watcher.ipToWorkload.Load(key) + } + assert.True(t, ok, "expected key %s to be added", key) + assert.Equal(t, expectedVal, val, "value for key %s mismatch", key) + } + + // Check that sliceToKeysMap now contains exactly the new keys. + val, ok := watcher.sliceToKeysMap.Load(string(newSlice.UID)) + assert.True(t, ok, "expected sliceToKeysMap entry for UID %s", newSlice.UID) + gotKeys := val.([]string) + sort.Strings(gotKeys) + expectedKeys := []string{"1.2.3.5", "1.2.3.5:443", "othersvc@testns"} + sort.Strings(expectedKeys) + assert.True(t, reflect.DeepEqual(expectedKeys, gotKeys), "sliceToKeysMap keys mismatch, got: %v, want: %v", gotKeys, expectedKeys) + }) + + // --- Subtest: Partial overlap --- + t.Run("partial overlap", func(t *testing.T) { + watcher := newEndpointSliceWatcherForTest() + + // Old slice: UID "uid-3", Namespace "testns", svc label "mysvc", + // with one endpoint: TargetRef "workload-6d9b7f8597-b5l2j", Addresses ["1.2.3.4"], Port 80. + oldSlice := createTestEndpointSlice("uid-3", "testns", "mysvc", "workload-6d9b7f8597-b5l2j", []string{"1.2.3.4"}, []int32{80}) + watcher.handleSliceAdd(oldSlice) + + // New slice: same UID, same svc label ("mysvc") but now two endpoints. + // First endpoint: same as before: Addresses ["1.2.3.4"], Port 80. + // Second endpoint: Addresses ["1.2.3.5"], Port 80. + // (Since svc label remains, the service key "mysvc@testns" remains the same.) + // We expect the new keys to be the union of: + // From first endpoint: "1.2.3.4", "1.2.3.4:80" + // From second endpoint: "1.2.3.5", "1.2.3.5:80" + // And the service key "mysvc@testns". + name := "port-0" + protocol := v1.ProtocolTCP + newSlice := &discv1.EndpointSlice{ + ObjectMeta: metav1.ObjectMeta{ + UID: "uid-3", // same UID + Namespace: "testns", + Labels: map[string]string{ + "kubernetes.io/service-name": "mysvc", + }, + }, + // Two endpoints. + Endpoints: []discv1.Endpoint{ + { + Addresses: []string{"1.2.3.4"}, + TargetRef: &v1.ObjectReference{ + Kind: "Pod", + Name: "workload-6d9b7f8597-b5l2j", + Namespace: "testns", + }, + }, + { + Addresses: []string{"1.2.3.5"}, + TargetRef: &v1.ObjectReference{ + Kind: "Pod", + Name: "workload-6d9b7f8597-fx8px", + Namespace: "testns", + }, + }, + }, + // Single port: 80. + Ports: []discv1.EndpointPort{ + { + Name: &name, + Protocol: &protocol, + Port: func() *int32 { p := int32(80); return &p }(), + }, + }, + } + + // Call update handler. 
+ watcher.handleSliceUpdate(oldSlice, newSlice) + + expectedVal := "workload@testns" + // Expected keys now: + // From endpoint 1: "1.2.3.4", "1.2.3.4:80" + // From endpoint 2: "1.2.3.5", "1.2.3.5:80" + // And service key: "mysvc@testns" + expectedKeysIP := []string{"1.2.3.4", "1.2.3.4:80", "1.2.3.5", "1.2.3.5:80"} + expectedKeysSvc := []string{"mysvc@testns"} + + // Verify that all expected keys are present. + for _, key := range expectedKeysIP { + val, ok := watcher.ipToWorkload.Load(key) + assert.True(t, ok, "expected ipToWorkload key %s", key) + assert.Equal(t, expectedVal, val, "ipToWorkload[%s] mismatch", key) + } + for _, key := range expectedKeysSvc { + val, ok := watcher.serviceToWorkload.Load(key) + assert.True(t, ok, "expected serviceToWorkload key %s", key) + assert.Equal(t, expectedVal, val, "serviceToWorkload[%s] mismatch", key) + } + + // And check that sliceToKeysMap contains the union of the keys. + val, ok := watcher.sliceToKeysMap.Load("uid-3") + assert.True(t, ok, "expected sliceToKeysMap to contain uid-3") + gotKeys := val.([]string) + allExpected := append(expectedKeysIP, expectedKeysSvc...) + sort.Strings(gotKeys) + sort.Strings(allExpected) + assert.True(t, reflect.DeepEqual(allExpected, gotKeys), "sliceToKeysMap keys mismatch, got: %v, want: %v", gotKeys, allExpected) + }) +} diff --git a/internal/k8sCommon/k8sclient/kubernetes_utils.go b/internal/k8sCommon/k8sclient/kubernetes_utils.go new file mode 100644 index 0000000000..e39a8edfb3 --- /dev/null +++ b/internal/k8sCommon/k8sclient/kubernetes_utils.go @@ -0,0 +1,231 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: MIT + +package k8sclient + +import ( + "errors" + "fmt" + "net" + "regexp" + "strconv" + "strings" + "sync" + "time" + + corev1 "k8s.io/api/core/v1" +) + +const ( + // kubeAllowedStringAlphaNums holds the characters allowed in replicaset names from as parent deployment + // https://github.com/kubernetes/kubernetes/blob/master/staging/src/k8s.io/apimachinery/pkg/util/rand/rand.go#L121 + kubeAllowedStringAlphaNums = "bcdfghjklmnpqrstvwxz2456789" +) + +var ( + // ReplicaSet name = Deployment name + "-" + up to 10 alphanumeric characters string, if the ReplicaSet was created through a deployment + // The suffix string of the ReplicaSet name is an int32 number (0 to 4,294,967,295) that is cast to a string and then + // mapped to an alphanumeric value with only the following characters allowed: "bcdfghjklmnpqrstvwxz2456789". + // The suffix string length is therefore nondeterministic. The regex accepts a suffix of length 6-10 to account for + // ReplicaSets not managed by deployments that may have similar names. + // Suffix Generation: https://github.com/kubernetes/kubernetes/blob/master/pkg/controller/controller_utils.go#L1201 + // Alphanumeric Mapping: https://github.com/kubernetes/kubernetes/blob/master/staging/src/k8s.io/apimachinery/pkg/util/rand/rand.go#L121) + replicaSetWithDeploymentNamePattern = fmt.Sprintf(`^(.+)-[%s]{6,10}$`, kubeAllowedStringAlphaNums) + deploymentFromReplicaSetPattern = regexp.MustCompile(replicaSetWithDeploymentNamePattern) + // if a pod is launched directly by a replicaSet or daemonSet (with a given name by users), its name has the following pattern: + // Pod name = ReplicaSet name + 5 alphanumeric characters long string + // some code reference for daemon set: + // 1. 
daemonset uses the strategy to create pods: https://github.com/kubernetes/kubernetes/blob/82e3a671e79d1740ab9a3b3fac8a3bb7d065a6fb/pkg/registry/apps/daemonset/strategy.go#L46 + // 2. the strategy uses SimpleNameGenerator to create names: https://github.com/kubernetes/kubernetes/blob/82e3a671e79d1740ab9a3b3fac8a3bb7d065a6fb/staging/src/k8s.io/apiserver/pkg/storage/names/generate.go#L53 + // 3. the random name generator only use non vowels char + numbers: https://github.com/kubernetes/kubernetes/blob/82e3a671e79d1740ab9a3b3fac8a3bb7d065a6fb/staging/src/k8s.io/apimachinery/pkg/util/rand/rand.go#L83 + podWithSuffixPattern = fmt.Sprintf(`^(.+)-[%s]{5}$`, kubeAllowedStringAlphaNums) + replicaSetOrDaemonSetFromPodPattern = regexp.MustCompile(podWithSuffixPattern) + + // Pattern for StatefulSet: - + reStatefulSet = regexp.MustCompile(`^(.+)-(\d+)$`) +) + +func attachNamespace(resourceName, namespace string) string { + // character "@" is not allowed in kubernetes resource names: https://unofficial-kubernetes.readthedocs.io/en/latest/concepts/overview/working-with-objects/names/ + return resourceName + "@" + namespace +} + +func getServiceAndNamespace(service *corev1.Service) string { + return attachNamespace(service.Name, service.Namespace) +} + +func extractResourceAndNamespace(serviceOrWorkloadAndNamespace string) (string, string) { + // extract service name and namespace from serviceAndNamespace + parts := strings.Split(serviceOrWorkloadAndNamespace, "@") + if len(parts) != 2 { + return "", "" + } + return parts[0], parts[1] +} + +func extractWorkloadNameFromRS(replicaSetName string) (string, error) { + match := deploymentFromReplicaSetPattern.FindStringSubmatch(replicaSetName) + if match != nil { + return match[1], nil + } + + return "", errors.New("failed to extract workload name from replicatSet name: " + replicaSetName) +} + +func extractWorkloadNameFromPodName(podName string) (string, error) { + match := replicaSetOrDaemonSetFromPodPattern.FindStringSubmatch(podName) + if match != nil { + return match[1], nil + } + + return "", errors.New("failed to extract workload name from pod name: " + podName) +} + +func getWorkloadAndNamespace(pod *corev1.Pod) string { + var workloadAndNamespace string + if pod.ObjectMeta.OwnerReferences != nil { + for _, ownerRef := range pod.ObjectMeta.OwnerReferences { + if workloadAndNamespace != "" { + break + } + + if ownerRef.Kind == "ReplicaSet" { + if workloadName, err := extractWorkloadNameFromRS(ownerRef.Name); err == nil { + // when the replicaSet is created by a deployment, use deployment name + workloadAndNamespace = attachNamespace(workloadName, pod.Namespace) + } else if workloadName, err := extractWorkloadNameFromPodName(pod.Name); err == nil { + // when the replicaSet is not created by a deployment, use replicaSet name directly + workloadAndNamespace = attachNamespace(workloadName, pod.Namespace) + } + } else if ownerRef.Kind == "StatefulSet" { + workloadAndNamespace = attachNamespace(ownerRef.Name, pod.Namespace) + } else if ownerRef.Kind == "DaemonSet" { + workloadAndNamespace = attachNamespace(ownerRef.Name, pod.Namespace) + } + } + } + + return workloadAndNamespace +} + +// InferWorkloadName tries to parse the given podName to find the top-level workload name. +// +// 1) If it matches -, return . +// 2) If it matches -<5charSuffix>: +// - If is -<6–10charSuffix>, return . +// - Else return (likely a bare ReplicaSet or DaemonSet). +// +// 3) If no pattern matches, return the original podName. 
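+// Examples (mirroring the unit tests): "mysql-0" -> "mysql", "nginx-76977669dc-lwx64" -> "nginx",
+// "nginx-b2dfg" -> "nginx", and a non-matching name such as "simplepod" falls back to the service name.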
+// +// Caveat: You can't reliably distinguish DaemonSet vs. bare ReplicaSet by name alone. +// In some edge cases when the deployment name is longer than 47 characters, The regex pattern is +// not reliable. See reference: +// - https://pauldally.medium.com/why-you-try-to-keep-your-deployment-names-to-47-characters-or-less-1f93a848d34c +// - https://github.com/kubernetes/kubernetes/issues/116447#issuecomment-1530652258 +// +// For that, we fall back to use service name as last defense. +func inferWorkloadName(podName, fallbackServiceName string) string { + // 1) Check if it's a StatefulSet pod: - + if matches := reStatefulSet.FindStringSubmatch(podName); matches != nil { + return matches[1] // e.g. "mysql-0" => "mysql" + } + + // 2) Check if it's a Pod with a 5-char random suffix: -<5Chars> + if matches := replicaSetOrDaemonSetFromPodPattern.FindStringSubmatch(podName); matches != nil { + parentName := matches[1] + + // If parentName ends with 6–10 random chars, that parent is a Deployment-based ReplicaSet. + // So the top-level workload is the first part before that suffix. + if rsMatches := deploymentFromReplicaSetPattern.FindStringSubmatch(parentName); rsMatches != nil { + return rsMatches[1] // e.g. "nginx-a2b3c4" => "nginx" + } + + // Otherwise, it's a "bare" ReplicaSet or DaemonSet—just return parentName. + return parentName + } + + // 3) If none of the patterns matched, return the service name as fallback + if fallbackServiceName != "" { + return fallbackServiceName + } + + // 4) Finally return the full pod name (I don't think this will happen) + return podName +} + +const IP_PORT_PATTERN = `^(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}):(\d+)$` + +var ipPortRegex = regexp.MustCompile(IP_PORT_PATTERN) + +func extractIPPort(ipPort string) (string, string, bool) { + match := ipPortRegex.MatchString(ipPort) + + if !match { + return "", "", false + } + + result := ipPortRegex.FindStringSubmatch(ipPort) + if len(result) != 3 { + return "", "", false + } + + ip := result[1] + port := result[2] + + return ip, port, true +} + +func getHostNetworkPorts(pod *corev1.Pod) []string { + var ports []string + if !pod.Spec.HostNetwork { + return ports + } + for _, container := range pod.Spec.Containers { + for _, port := range container.Ports { + if port.HostPort != 0 { + ports = append(ports, strconv.Itoa(int(port.HostPort))) + } + } + } + return ports +} + +func isIP(ipString string) bool { + ip := net.ParseIP(ipString) + return ip != nil +} + +// a safe channel which can be closed multiple times +type safeChannel struct { + sync.Mutex + + ch chan struct{} + closed bool +} + +func (sc *safeChannel) Close() { + sc.Lock() + defer sc.Unlock() + + if !sc.closed { + close(sc.ch) + sc.closed = true + } +} + +// Deleter represents a type that can delete a key from a map after a certain delay. +type Deleter interface { + DeleteWithDelay(m *sync.Map, key interface{}) +} + +// TimedDeleter deletes a key after a specified delay. +type TimedDeleter struct { + Delay time.Duration +} + +func (td *TimedDeleter) DeleteWithDelay(m *sync.Map, key interface{}) { + go func() { + time.Sleep(td.Delay) + m.Delete(key) + }() +} diff --git a/internal/k8sCommon/k8sclient/kubernetes_utils_test.go b/internal/k8sCommon/k8sclient/kubernetes_utils_test.go new file mode 100644 index 0000000000..9cce686ca7 --- /dev/null +++ b/internal/k8sCommon/k8sclient/kubernetes_utils_test.go @@ -0,0 +1,258 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: MIT + +package k8sclient + +import ( + "testing" + + "github.com/stretchr/testify/assert" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +// TestAttachNamespace function +func TestAttachNamespace(t *testing.T) { + result := attachNamespace("testResource", "testNamespace") + if result != "testResource@testNamespace" { + t.Errorf("attachNamespace was incorrect, got: %s, want: %s.", result, "testResource@testNamespace") + } +} + +// TestGetServiceAndNamespace function +func TestGetServiceAndNamespace(t *testing.T) { + service := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "testService", + Namespace: "testNamespace", + }, + } + result := getServiceAndNamespace(service) + if result != "testService@testNamespace" { + t.Errorf("getServiceAndNamespace was incorrect, got: %s, want: %s.", result, "testService@testNamespace") + } +} + +// TestExtractResourceAndNamespace function +func TestExtractResourceAndNamespace(t *testing.T) { + // Test normal case + name, namespace := extractResourceAndNamespace("testService@testNamespace") + if name != "testService" || namespace != "testNamespace" { + t.Errorf("extractResourceAndNamespace was incorrect, got: %s and %s, want: %s and %s.", name, namespace, "testService", "testNamespace") + } + + // Test invalid case + name, namespace = extractResourceAndNamespace("invalid") + if name != "" || namespace != "" { + t.Errorf("extractResourceAndNamespace was incorrect, got: %s and %s, want: %s and %s.", name, namespace, "", "") + } +} + +func TestExtractWorkloadNameFromRS(t *testing.T) { + testCases := []struct { + name string + replicaSetName string + want string + shouldErr bool + }{ + { + name: "Valid ReplicaSet Name", + replicaSetName: "my-deployment-5859ffc7ff", + want: "my-deployment", + shouldErr: false, + }, + { + name: "Invalid ReplicaSet Name - No Hyphen", + replicaSetName: "mydeployment5859ffc7ff", + want: "", + shouldErr: true, + }, + { + name: "Invalid ReplicaSet Name - Less Than 10 Suffix Characters", + replicaSetName: "my-deployment-bc2", + want: "", + shouldErr: true, + }, + { + name: "Invalid ReplicaSet Name - More Than 10 Suffix Characters", + replicaSetName: "my-deployment-5859ffc7ffx", + want: "", + shouldErr: true, + }, + { + name: "Invalid ReplicaSet Name - Invalid Characters in Suffix", + replicaSetName: "my-deployment-aeiou12345", + want: "", + shouldErr: true, + }, + { + name: "Invalid ReplicaSet Name - Empty String", + replicaSetName: "", + want: "", + shouldErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := extractWorkloadNameFromRS(tc.replicaSetName) + + if (err != nil) != tc.shouldErr { + t.Errorf("extractWorkloadNameFromRS() error = %v, wantErr %v", err, tc.shouldErr) + return + } + + if got != tc.want { + t.Errorf("extractWorkloadNameFromRS() = %v, want %v", got, tc.want) + } + }) + } +} + +func TestExtractWorkloadNameFromPodName(t *testing.T) { + testCases := []struct { + name string + podName string + want string + shouldErr bool + }{ + { + name: "Valid Pod Name", + podName: "my-replicaset-bc24f", + want: "my-replicaset", + shouldErr: false, + }, + { + name: "Invalid Pod Name - No Hyphen", + podName: "myreplicasetbc24f", + want: "", + shouldErr: true, + }, + { + name: "Invalid Pod Name - Less Than 5 Suffix Characters", + podName: "my-replicaset-bc2", + want: "", + shouldErr: true, + }, + { + name: "Invalid Pod Name - More Than 5 Suffix Characters", + podName: "my-replicaset-bc24f5", + want: "", + 
shouldErr: true, + }, + { + name: "Invalid Pod Name - Empty String", + podName: "", + want: "", + shouldErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := extractWorkloadNameFromPodName(tc.podName) + + if (err != nil) != tc.shouldErr { + t.Errorf("extractWorkloadNameFromPodName() error = %v, wantErr %v", err, tc.shouldErr) + return + } + + if got != tc.want { + t.Errorf("extractWorkloadNameFromPodName() = %v, want %v", got, tc.want) + } + }) + } +} + +// TestGetWorkloadAndNamespace function +func TestGetWorkloadAndNamespace(t *testing.T) { + // Test ReplicaSet case + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "testPod", + Namespace: "testNamespace", + OwnerReferences: []metav1.OwnerReference{ + { + Kind: "ReplicaSet", + Name: "testDeployment-5d68bc5f49", + }, + }, + }, + } + result := getWorkloadAndNamespace(pod) + if result != "testDeployment@testNamespace" { + t.Errorf("getDeploymentAndNamespace was incorrect, got: %s, want: %s.", result, "testDeployment@testNamespace") + } + + // Test StatefulSet case + pod.ObjectMeta.OwnerReferences[0].Kind = "StatefulSet" + pod.ObjectMeta.OwnerReferences[0].Name = "testStatefulSet" + result = getWorkloadAndNamespace(pod) + if result != "testStatefulSet@testNamespace" { + t.Errorf("getWorkloadAndNamespace was incorrect, got: %s, want: %s.", result, "testStatefulSet@testNamespace") + } + + // Test Other case + pod.ObjectMeta.OwnerReferences[0].Kind = "Other" + pod.ObjectMeta.OwnerReferences[0].Name = "testOther" + result = getWorkloadAndNamespace(pod) + if result != "" { + t.Errorf("getWorkloadAndNamespace was incorrect, got: %s, want: %s.", result, "") + } + + // Test no OwnerReferences case + pod.ObjectMeta.OwnerReferences = nil + result = getWorkloadAndNamespace(pod) + if result != "" { + t.Errorf("getWorkloadAndNamespace was incorrect, got: %s, want: %s.", result, "") + } +} + +func TestExtractIPPort(t *testing.T) { + // Test valid IP:Port + ip, port, ok := extractIPPort("192.0.2.0:8080") + assert.Equal(t, "192.0.2.0", ip) + assert.Equal(t, "8080", port) + assert.True(t, ok) + + // Test invalid IP:Port + ip, port, ok = extractIPPort("192.0.2:8080") + assert.Equal(t, "", ip) + assert.Equal(t, "", port) + assert.False(t, ok) + + // Test IP only + ip, port, ok = extractIPPort("192.0.2.0") + assert.Equal(t, "", ip) + assert.Equal(t, "", port) + assert.False(t, ok) +} + +func TestInferWorkloadName(t *testing.T) { + testCases := []struct { + name string + input string + service string + expected string + }{ + {"StatefulSet single digit", "mysql-0", "service", "mysql"}, + {"StatefulSet multiple digits", "mysql-10", "service", "mysql"}, + {"ReplicaSet bare pod", "nginx-b2dfg", "service", "nginx"}, + {"Deployment-based ReplicaSet pod", "nginx-76977669dc-lwx64", "service", "nginx"}, + {"Non matching", "simplepod", "service", "service"}, + {"ReplicaSet name with number suffix", "nginx-123-d9stt", "service", "nginx-123"}, + {"Some confusing case with a replicaSet/daemonset name matching the pattern", "nginx-245678-d9stt", "nginx-service", "nginx"}, + // when the regex pattern doesn't matter, we just fall back to service name to handle all the edge cases + {"Some confusing case with a replicaSet/daemonset name not matching the pattern", "nginx-123456-d9stt", "nginx-service", "nginx-123456"}, + {"Empty", "", "service", "service"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got := inferWorkloadName(tc.input, tc.service) + if got != tc.expected { + 
t.Errorf("inferWorkloadName(%q) = %q; expected %q", tc.input, got, tc.expected) + } + }) + } +}