Skip to content

Commit

Permalink
get cloud provider from node
Browse files Browse the repository at this point in the history
Signed-off-by: David Wertenteil <[email protected]>
  • Loading branch information
David Wertenteil committed Oct 22, 2023
1 parent 42ff7b5 commit b4b054a
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 55 deletions.
6 changes: 4 additions & 2 deletions cloudsupport/cloudproviderconfiguration.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"os"
"strings"

corev1 "k8s.io/api/core/v1"

cloudsupportv1 "github.com/kubescape/k8s-interface/cloudsupport/v1"
"github.com/kubescape/k8s-interface/k8sinterface"
"github.com/kubescape/k8s-interface/workloadinterface"
Expand Down Expand Up @@ -36,8 +38,8 @@ func GetKubeContextName() string {
}

// GetCloudProvider returns the cloud provider name
func GetCloudProvider() string {
return cloudsupportv1.GetCloudProvider()
func GetCloudProvider(nodeList *corev1.NodeList) string {
return cloudsupportv1.GetCloudProvider(nodeList)
}

// GetDescriptiveInfoFromCloudProvider returns the cluster description from the cloud provider wrapped in IMetadata obj
Expand Down
33 changes: 20 additions & 13 deletions cloudsupport/v1/cloudproviderv1.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"fmt"
"strings"

corev1 "k8s.io/api/core/v1"

"github.com/kubescape/k8s-interface/cloudsupport/apis"
"github.com/kubescape/k8s-interface/k8sinterface"
"github.com/kubescape/k8s-interface/workloadinterface"
Expand All @@ -28,15 +30,15 @@ const (
NotSupportedMsg = "Not supported"
)

// GetCloudProvider get cloud provider name from gitVersion/server URL
func GetCloudProvider() string {
// GetCloudProvider get cloud provider name from gitVersion/nodes
func GetCloudProvider(nodeList *corev1.NodeList) string {
if IsEKS() {
return EKS
}
if IsGKE() {
if IsGKE(nodeList) {
return GKE
}
if IsAKS() {
if IsAKS(nodeList) {
return AKS
}
return ""
Expand Down Expand Up @@ -324,10 +326,8 @@ func GetPolicyVersionAKS(aksSupport IAKSSupport, cluster string, subscriptionId
}

// check if the server is AKS. e.g. https://XXX.XX.XXX.azmk8s.io:443
func IsAKS() bool {
const serverIdentifierAKS = "azmk8s.io"
clusterServerName := k8sinterface.GetK8sConfigClusterServerName()
return strings.Contains(clusterServerName, serverIdentifierAKS)
func IsAKS(nodeList *corev1.NodeList) bool {
return labelHasCloudPrefix(nodeList, "aks-")
}

// check if the server is EKS. e.g. arn:aws:eks:eu-west-1:xxx:cluster/xxxx
Expand All @@ -340,10 +340,17 @@ func IsEKS() bool {
}

// check if the server is GKE. e.g. gke_xxx-xx-0000_us-central1-c_xxxx-1
func IsGKE() bool {
version, err := k8sinterface.GetK8SServerGitVersion()
if err != nil {
return false
func IsGKE(nodeList *corev1.NodeList) bool {
return labelHasCloudPrefix(nodeList, "gke-")
}

func labelHasCloudPrefix(nodeList *corev1.NodeList, cloud string) bool {
for _, node := range nodeList.Items {
if val, ok := node.Labels["kubernetes.io/hostname"]; ok {
if strings.HasPrefix(val, cloud) {
return true
}
}
}
return strings.Contains(version, GKE)
return false
}
69 changes: 29 additions & 40 deletions cloudsupport/v1/cloudproviderv1_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"encoding/json"
"testing"

corev1 "k8s.io/api/core/v1"

"github.com/kubescape/k8s-interface/cloudsupport/apis"
"github.com/kubescape/k8s-interface/k8sinterface"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -299,8 +301,9 @@ func Test_IsGKE(t *testing.T) {
defer tearDown()

type args struct {
config *clientcmdapi.Config
context string
config *clientcmdapi.Config
context string
labelMap map[string]string
}
tests := []struct {
name string
Expand All @@ -312,17 +315,25 @@ func Test_IsGKE(t *testing.T) {
args: args{
config: getKubeConfigMock(),
context: "gke_xxx-xx-0000_us-central1-c_xxxx-1",
labelMap: map[string]string{
"kubernetes.io/hostname": "gke-xxx-xx-0000_us-central1-c_xxxx-1",
},
},
want: true,
},
}
for _, tt := range tests {

t.Run(tt.name, func(t *testing.T) {
// set context
k8sinterface.SetClientConfigAPI(tt.args.config)
k8sinterface.SetK8SGitServerVersion("gke_xxx-xx-0000_us-central1-c_xxxx-1")
if got := IsGKE(); got != tt.want {

node := corev1.Node{}
node.Labels = tt.args.labelMap
nodeList := &corev1.NodeList{}
nodeList.Items = []corev1.Node{node}

if got := IsGKE(nodeList); got != tt.want {
t.Errorf("IsGKE() = %v, want %v", got, tt.want)
}
})
Expand Down Expand Up @@ -363,41 +374,6 @@ func Test_IsEKS(t *testing.T) {
}
}

func Test_IsAKS(t *testing.T) {
defer tearDown()

type args struct {
config *clientcmdapi.Config
context string
}
tests := []struct {
name string
args args
want bool
}{
{
name: "Test_IsAKS",
args: args{
config: getKubeConfigMock(),
context: "xxxx-2",
},
want: true,
},
}
for _, tt := range tests {

t.Run(tt.name, func(t *testing.T) {
// set context
k8sinterface.SetClientConfigAPI(tt.args.config)
k8sinterface.SetClusterContextName(tt.args.context)
k8sinterface.SetConfigClusterServerName("https://XXX.XX.XXX.azmk8s.io:443")
if got := IsAKS(); got != tt.want {
t.Errorf("IsAKS() = %v, want %v", got, tt.want)
}
})
}
}

func Test_GetK8sConfigClusterServerName2(t *testing.T) {
defer tearDown()

Expand Down Expand Up @@ -488,13 +464,17 @@ func TestGetCloudProvider(t *testing.T) {
context string
expected string
gitVersion string
labelMap map[string]string
}{
{
name: "AKS",
config: configMock,
context: "0-context",
expected: AKS,
gitVersion: "v1",
labelMap: map[string]string{
"kubernetes.io/hostname": "aks-agentpool-xxxx-0",
},
},
{
name: "EKS",
Expand All @@ -509,6 +489,9 @@ func TestGetCloudProvider(t *testing.T) {
context: "2-context",
expected: GKE,
gitVersion: "gke",
labelMap: map[string]string{
"kubernetes.io/hostname": "gke-agentpool-xxxx-0",
},
},
{
name: "Unknown",
Expand All @@ -525,7 +508,13 @@ func TestGetCloudProvider(t *testing.T) {
k8sinterface.SetClusterContextName(tt.context)
k8sinterface.SetK8SGitServerVersion(tt.gitVersion)
k8sinterface.SetClientConfigAPI(tt.config)
if got := GetCloudProvider(); got != tt.expected {

node := corev1.Node{}
node.Labels = tt.labelMap
nodeList := &corev1.NodeList{}
nodeList.Items = []corev1.Node{node}

if got := GetCloudProvider(nodeList); got != tt.expected {
t.Errorf("GetCloudProvider() = %v, want %v", got, tt.expected)
}
})
Expand Down

0 comments on commit b4b054a

Please sign in to comment.