diff --git a/aws/cf.go b/aws/cf.go index c9ed72a7..6864290e 100644 --- a/aws/cf.go +++ b/aws/cf.go @@ -35,6 +35,7 @@ type Stack struct { WAFWebACLID string CertificateARNs map[string]time.Time tags map[string]string + Subnets []string } // IsComplete returns true if the stack status is a complete state. @@ -480,6 +481,11 @@ func mapToManagedStack(stack *cloudformation.Stack) *Stack { http2 = false } + var subnets []string + if parameters[parameterLoadBalancerSubnetsParameter] != "" { + subnets = strings.Split(parameters[parameterLoadBalancerSubnetsParameter], ",") + } + return &Stack{ Name: aws.StringValue(stack.StackName), DNSName: outputs.dnsName(), @@ -497,6 +503,7 @@ func mapToManagedStack(stack *cloudformation.Stack) *Stack { statusReason: aws.StringValue(stack.StackStatusReason), CWAlarmConfigHash: tags[cwAlarmConfigHashTag], WAFWebACLID: parameters[parameterLoadBalancerWAFWebACLIDParameter], + Subnets: subnets, } } diff --git a/go.mod b/go.mod index 380617fd..5c95dd37 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.9.0 github.com/zalando/skipper v0.21.54 + golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f k8s.io/api v0.28.9 k8s.io/apimachinery v0.28.9 k8s.io/client-go v0.28.9 @@ -70,14 +71,12 @@ require ( go.uber.org/atomic v1.11.0 // indirect go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect golang.org/x/crypto v0.22.0 // indirect - golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect golang.org/x/net v0.24.0 // indirect golang.org/x/oauth2 v0.19.0 // indirect golang.org/x/sys v0.19.0 // indirect golang.org/x/term v0.19.0 // indirect golang.org/x/text v0.14.0 // indirect golang.org/x/time v0.5.0 // indirect - golang.org/x/tools v0.20.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240415180920-8c6c420018be // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240415180920-8c6c420018be // indirect google.golang.org/grpc v1.63.2 // indirect diff --git a/go.sum b/go.sum index b014966d..0f223f9a 100644 --- a/go.sum +++ b/go.sum @@ -276,8 +276,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= -golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= -golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= +golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f h1:99ci1mjWVBWwJiEKYY6jWa4d2nTQVIEhZIptnrVb1XY= +golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= diff --git a/worker.go b/worker.go index 86e9beea..692d5cec 100644 --- a/worker.go +++ b/worker.go @@ -19,6 +19,7 @@ import ( "github.com/zalando-incubator/kube-ingress-aws-controller/certs" "github.com/zalando-incubator/kube-ingress-aws-controller/kubernetes" "github.com/zalando-incubator/kube-ingress-aws-controller/problem" + "golang.org/x/exp/slices" ) type loadBalancer struct { @@ -333,7 +334,7 @@ func doWork( awsAdapter.UpdateTargetGroupsAndAutoScalingGroups(stacks, problems) certs := NewCertificates(certificateSummaries) - model := buildManagedModel(certs, certsPerALB, certTTL, ingresses, stacks, cwAlarms, globalWAFACL) + model := buildManagedModel(certs, certsPerALB, certTTL, ingresses, stacks, cwAlarms, globalWAFACL, awsAdapter.FindLBSubnets) log.Debugf("Have %d model(s)", len(model)) for _, loadBalancer := range model { switch loadBalancer.Status() { @@ -408,6 +409,7 @@ func matchIngressesToLoadBalancers( certs CertificatesFinder, certsPerALB int, ingresses []*kubernetes.Ingress, + subnetsByScheme func(scheme string) []string, ) []*loadBalancer { clusterLocalLB := &loadBalancer{ clusterLocal: true, @@ -451,6 +453,17 @@ func matchIngressesToLoadBalancers( continue } + // Ignore NLBs with a wrong set of subnets + if lb.loadBalancerType == aws.LoadBalancerTypeNetwork && lb.stack != nil { + subnets := subnetsByScheme(lb.scheme) + sort.Strings(subnets) + sort.Strings(lb.stack.Subnets) + + if !slices.Equal[[]string](lb.stack.Subnets, subnets) { + continue + } + } + if lb.addIngress(certificateARNs, ingress, certsPerALB) { added = true break @@ -516,11 +529,12 @@ func buildManagedModel( stacks []*aws.Stack, cwAlarms aws.CloudWatchAlarmList, globalWAFACL string, + subnetsByScheme func(scheme string) []string, ) []*loadBalancer { sortStacks(stacks) attachGlobalWAFACL(ingresses, globalWAFACL) model := getAllLoadBalancers(certs, certTTL, stacks) - model = matchIngressesToLoadBalancers(model, certs, certsPerALB, ingresses) + model = matchIngressesToLoadBalancers(model, certs, certsPerALB, ingresses, subnetsByScheme) attachCloudWatchAlarms(model, cwAlarms) return model diff --git a/worker_test.go b/worker_test.go index bb0ff981..38f4f2e6 100644 --- a/worker_test.go +++ b/worker_test.go @@ -18,10 +18,12 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/zalando-incubator/kube-ingress-aws-controller/aws" awsAdapter "github.com/zalando-incubator/kube-ingress-aws-controller/aws" "github.com/zalando-incubator/kube-ingress-aws-controller/certs" "github.com/zalando-incubator/kube-ingress-aws-controller/kubernetes" "github.com/zalando/skipper/dataclients/kubernetes/kubernetestest" + "golang.org/x/exp/slices" "k8s.io/apimachinery/pkg/util/wait" "github.com/zalando-incubator/kube-ingress-aws-controller/aws/fake" @@ -1179,6 +1181,7 @@ func TestMatchIngressesToLoadbalancers(t *testing.T) { maxCertsPerLB int lbs []*loadBalancer ingresses []*kubernetes.Ingress + subnets []string validate func(*testing.T, []*loadBalancer) }{{ title: "only cluster local", @@ -1361,6 +1364,37 @@ func TestMatchIngressesToLoadbalancers(t *testing.T) { require.Equal(t, 1, len(lb.ingresses["foo"])) } }, + }, { + title: "load balancer with invalid subnets", + ingresses: []*kubernetes.Ingress{{ + Name: "foo-ingress", + Hostnames: []string{ + "foo.org", + "bar.org", + }, + LoadBalancerType: awsAdapter.LoadBalancerTypeNetwork, + Shared: true, + }}, + lbs: []*loadBalancer{{ + loadBalancerType: awsAdapter.LoadBalancerTypeNetwork, + ingresses: make(map[string][]*kubernetes.Ingress), + stack: &aws.Stack{Subnets: []string{"a", "b", "c"}}, + }}, + validate: func(t *testing.T, lbs []*loadBalancer) { + require.Equal(t, 3, len(lbs)) + for _, lb := range lbs { + if lb.clusterLocal { + continue + } + + if lb.stack != nil && slices.Equal[[]string](lb.stack.Subnets, []string{"a", "b", "c"}) { + require.Len(t, lb.ingresses, 0) + } else { + require.Len(t, lb.ingresses, 1) + } + } + }, + subnets: []string{"x", "y", "z"}, }} { t.Run(test.title, func(t *testing.T) { var certs CertificatesFinder = defaultCerts @@ -1373,7 +1407,11 @@ func TestMatchIngressesToLoadbalancers(t *testing.T) { maxCertsPerLB = test.maxCertsPerLB } - lbs := matchIngressesToLoadBalancers(test.lbs, certs, maxCertsPerLB, test.ingresses) + subnetsByScheme := func(scheme string) []string { + return test.subnets + } + + lbs := matchIngressesToLoadBalancers(test.lbs, certs, maxCertsPerLB, test.ingresses, subnetsByScheme) test.validate(t, lbs) }) } @@ -1403,6 +1441,7 @@ func TestBuildModel(t *testing.T) { stacks []*awsAdapter.Stack alarms awsAdapter.CloudWatchAlarmList globalWAFACL string + subnets []string validate func(*testing.T, []*loadBalancer) }{{ title: "no alarm, no waf", @@ -1549,6 +1588,10 @@ func TestBuildModel(t *testing.T) { maxCertsPerLB = test.maxCertsPerLB } + subnetsByScheme := func(scheme string) []string { + return test.subnets + } + m := buildManagedModel( certs, maxCertsPerLB, @@ -1557,6 +1600,7 @@ func TestBuildModel(t *testing.T) { test.stacks, test.alarms, test.globalWAFACL, + subnetsByScheme, ) test.validate(t, m)