diff --git a/cloudsupport/cloudproviderconfiguration.go b/cloudsupport/cloudproviderconfiguration.go index 586f110..727d874 100644 --- a/cloudsupport/cloudproviderconfiguration.go +++ b/cloudsupport/cloudproviderconfiguration.go @@ -5,6 +5,8 @@ import ( "os" "strings" + corev1 "k8s.io/api/core/v1" + cloudsupportv1 "github.com/kubescape/k8s-interface/cloudsupport/v1" "github.com/kubescape/k8s-interface/k8sinterface" "github.com/kubescape/k8s-interface/workloadinterface" @@ -36,8 +38,8 @@ func GetKubeContextName() string { } // GetCloudProvider returns the cloud provider name -func GetCloudProvider() string { - return cloudsupportv1.GetCloudProvider() +func GetCloudProvider(nodeList *corev1.NodeList) string { + return cloudsupportv1.GetCloudProvider(nodeList) } // GetDescriptiveInfoFromCloudProvider returns the cluster description from the cloud provider wrapped in IMetadata obj diff --git a/cloudsupport/v1/cloudproviderv1.go b/cloudsupport/v1/cloudproviderv1.go index 13f8036..7f1a8ed 100644 --- a/cloudsupport/v1/cloudproviderv1.go +++ b/cloudsupport/v1/cloudproviderv1.go @@ -5,6 +5,8 @@ import ( "fmt" "strings" + corev1 "k8s.io/api/core/v1" + "github.com/kubescape/k8s-interface/cloudsupport/apis" "github.com/kubescape/k8s-interface/k8sinterface" "github.com/kubescape/k8s-interface/workloadinterface" @@ -28,15 +30,15 @@ const ( NotSupportedMsg = "Not supported" ) -// GetCloudProvider get cloud provider name from gitVersion/server URL -func GetCloudProvider() string { +// GetCloudProvider get cloud provider name from gitVersion/nodes +func GetCloudProvider(nodeList *corev1.NodeList) string { if IsEKS() { return EKS } - if IsGKE() { + if IsGKE(nodeList) { return GKE } - if IsAKS() { + if IsAKS(nodeList) { return AKS } return "" @@ -324,10 +326,8 @@ func GetPolicyVersionAKS(aksSupport IAKSSupport, cluster string, subscriptionId } // check if the server is AKS. e.g. https://XXX.XX.XXX.azmk8s.io:443 -func IsAKS() bool { - const serverIdentifierAKS = "azmk8s.io" - clusterServerName := k8sinterface.GetK8sConfigClusterServerName() - return strings.Contains(clusterServerName, serverIdentifierAKS) +func IsAKS(nodeList *corev1.NodeList) bool { + return labelHasCloudPrefix(nodeList, "aks-") } // check if the server is EKS. e.g. arn:aws:eks:eu-west-1:xxx:cluster/xxxx @@ -340,10 +340,17 @@ func IsEKS() bool { } // check if the server is GKE. e.g. gke_xxx-xx-0000_us-central1-c_xxxx-1 -func IsGKE() bool { - version, err := k8sinterface.GetK8SServerGitVersion() - if err != nil { - return false +func IsGKE(nodeList *corev1.NodeList) bool { + return labelHasCloudPrefix(nodeList, "gke-") +} + +func labelHasCloudPrefix(nodeList *corev1.NodeList, cloud string) bool { + for _, node := range nodeList.Items { + if val, ok := node.Labels["kubernetes.io/hostname"]; ok { + if strings.HasPrefix(val, cloud) { + return true + } + } } - return strings.Contains(version, GKE) + return false } diff --git a/cloudsupport/v1/cloudproviderv1_test.go b/cloudsupport/v1/cloudproviderv1_test.go index c99af90..9ab3044 100644 --- a/cloudsupport/v1/cloudproviderv1_test.go +++ b/cloudsupport/v1/cloudproviderv1_test.go @@ -5,6 +5,8 @@ import ( "encoding/json" "testing" + corev1 "k8s.io/api/core/v1" + "github.com/kubescape/k8s-interface/cloudsupport/apis" "github.com/kubescape/k8s-interface/k8sinterface" "github.com/stretchr/testify/assert" @@ -299,8 +301,9 @@ func Test_IsGKE(t *testing.T) { defer tearDown() type args struct { - config *clientcmdapi.Config - context string + config *clientcmdapi.Config + context string + labelMap map[string]string } tests := []struct { name string @@ -312,17 +315,25 @@ func Test_IsGKE(t *testing.T) { args: args{ config: getKubeConfigMock(), context: "gke_xxx-xx-0000_us-central1-c_xxxx-1", + labelMap: map[string]string{ + "kubernetes.io/hostname": "gke-xxx-xx-0000_us-central1-c_xxxx-1", + }, }, want: true, }, } for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { // set context k8sinterface.SetClientConfigAPI(tt.args.config) k8sinterface.SetK8SGitServerVersion("gke_xxx-xx-0000_us-central1-c_xxxx-1") - if got := IsGKE(); got != tt.want { + + node := corev1.Node{} + node.Labels = tt.args.labelMap + nodeList := &corev1.NodeList{} + nodeList.Items = []corev1.Node{node} + + if got := IsGKE(nodeList); got != tt.want { t.Errorf("IsGKE() = %v, want %v", got, tt.want) } }) @@ -363,41 +374,6 @@ func Test_IsEKS(t *testing.T) { } } -func Test_IsAKS(t *testing.T) { - defer tearDown() - - type args struct { - config *clientcmdapi.Config - context string - } - tests := []struct { - name string - args args - want bool - }{ - { - name: "Test_IsAKS", - args: args{ - config: getKubeConfigMock(), - context: "xxxx-2", - }, - want: true, - }, - } - for _, tt := range tests { - - t.Run(tt.name, func(t *testing.T) { - // set context - k8sinterface.SetClientConfigAPI(tt.args.config) - k8sinterface.SetClusterContextName(tt.args.context) - k8sinterface.SetConfigClusterServerName("https://XXX.XX.XXX.azmk8s.io:443") - if got := IsAKS(); got != tt.want { - t.Errorf("IsAKS() = %v, want %v", got, tt.want) - } - }) - } -} - func Test_GetK8sConfigClusterServerName2(t *testing.T) { defer tearDown() @@ -488,6 +464,7 @@ func TestGetCloudProvider(t *testing.T) { context string expected string gitVersion string + labelMap map[string]string }{ { name: "AKS", @@ -495,6 +472,9 @@ func TestGetCloudProvider(t *testing.T) { context: "0-context", expected: AKS, gitVersion: "v1", + labelMap: map[string]string{ + "kubernetes.io/hostname": "aks-agentpool-xxxx-0", + }, }, { name: "EKS", @@ -509,6 +489,9 @@ func TestGetCloudProvider(t *testing.T) { context: "2-context", expected: GKE, gitVersion: "gke", + labelMap: map[string]string{ + "kubernetes.io/hostname": "gke-agentpool-xxxx-0", + }, }, { name: "Unknown", @@ -525,7 +508,13 @@ func TestGetCloudProvider(t *testing.T) { k8sinterface.SetClusterContextName(tt.context) k8sinterface.SetK8SGitServerVersion(tt.gitVersion) k8sinterface.SetClientConfigAPI(tt.config) - if got := GetCloudProvider(); got != tt.expected { + + node := corev1.Node{} + node.Labels = tt.labelMap + nodeList := &corev1.NodeList{} + nodeList.Items = []corev1.Node{node} + + if got := GetCloudProvider(nodeList); got != tt.expected { t.Errorf("GetCloudProvider() = %v, want %v", got, tt.expected) } })