From 8e2e2d874e60d06701023614150efbd5f1bc3fe9 Mon Sep 17 00:00:00 2001 From: Saylor Berman Date: Thu, 12 Dec 2024 17:08:28 -0700 Subject: [PATCH] Use state file for updating N+ upstreams (#2897) Problem: When splitting the data and control plane, we want to be able to maintain the ability to dynamically update upstreams without a reload if using NGINX Plus. With the current process, we write the upstreams into the nginx conf but don't reload, then call the API to actually do the update. Writing into the conf will trigger a reload from the agent once we split, however. There were also two bugs: - if metrics were disabled, the nginx plus client was not initialized, preventing API calls from occuring and instead a reload occurred - stream upstreams were not updated with the API Solution: Don't write the upstream servers into the nginx conf anymore when using NGINX Plus. Instead, utilize the `state` file option that the NGINX Plus API will populate with the upstream servers. This way we can just call the API and don't unintentionally reload by writing servers into the conf. Also added support for updating stream upstreams, and fixed the initialization bug. --- internal/mode/static/handler.go | 148 +++++++++------- internal/mode/static/handler_test.go | 122 ++++++++----- internal/mode/static/manager.go | 10 +- internal/mode/static/nginx/config/convert.go | 40 ++++- .../mode/static/nginx/config/convert_test.go | 34 ++++ .../mode/static/nginx/config/http/config.go | 7 +- .../mode/static/nginx/config/stream/config.go | 7 +- .../mode/static/nginx/config/upstreams.go | 25 ++- .../static/nginx/config/upstreams_template.go | 7 +- .../static/nginx/config/upstreams_test.go | 70 ++++++-- internal/mode/static/nginx/runtime/manager.go | 67 +++++-- .../mode/static/nginx/runtime/manager_test.go | 97 ++++++++-- .../runtime/runtimefakes/fake_manager.go | 114 ++++++++++-- .../runtimefakes/fake_nginx_plus_client.go | 166 ++++++++++++++++++ 14 files changed, 725 insertions(+), 189 deletions(-) diff --git a/internal/mode/static/handler.go b/internal/mode/static/handler.go index e09732e7f0..2a835b8584 100644 --- a/internal/mode/static/handler.go +++ b/internal/mode/static/handler.go @@ -2,6 +2,7 @@ package static import ( "context" + "errors" "fmt" "sync" "time" @@ -182,11 +183,11 @@ func (h *eventHandlerImpl) HandleEventBatch(ctx context.Context, logger logr.Log h.setLatestConfiguration(&cfg) - err = h.updateUpstreamServers( - ctx, - logger, - cfg, - ) + if h.cfg.plus { + err = h.updateUpstreamServers(cfg) + } else { + err = h.updateNginxConf(ctx, cfg) + } case state.ClusterStateChange: h.version++ cfg := dataplane.BuildConfiguration(ctx, gr, h.cfg.serviceResolver, h.version) @@ -198,10 +199,7 @@ func (h *eventHandlerImpl) HandleEventBatch(ctx context.Context, logger logr.Log h.setLatestConfiguration(&cfg) - err = h.updateNginxConf( - ctx, - cfg, - ) + err = h.updateNginxConf(ctx, cfg) } var nginxReloadRes status.NginxReloadResult @@ -306,7 +304,10 @@ func (h *eventHandlerImpl) parseAndCaptureEvent(ctx context.Context, logger logr } // updateNginxConf updates nginx conf files and reloads nginx. -func (h *eventHandlerImpl) updateNginxConf(ctx context.Context, conf dataplane.Configuration) error { +func (h *eventHandlerImpl) updateNginxConf( + ctx context.Context, + conf dataplane.Configuration, +) error { files := h.cfg.generator.Generate(conf) if err := h.cfg.nginxFileMgr.ReplaceFiles(files); err != nil { return fmt.Errorf("failed to replace NGINX configuration files: %w", err) @@ -316,89 +317,114 @@ func (h *eventHandlerImpl) updateNginxConf(ctx context.Context, conf dataplane.C return fmt.Errorf("failed to reload NGINX: %w", err) } + // If using NGINX Plus, update upstream servers using the API. + if err := h.updateUpstreamServers(conf); err != nil { + return fmt.Errorf("failed to update upstream servers: %w", err) + } + return nil } -// updateUpstreamServers is called only when endpoints have changed. It updates nginx conf files and then: -// - if using NGINX Plus, determines which servers have changed and uses the N+ API to update them; -// - otherwise if not using NGINX Plus, or an error was returned from the API, reloads nginx. -func (h *eventHandlerImpl) updateUpstreamServers( - ctx context.Context, - logger logr.Logger, - conf dataplane.Configuration, -) error { - isPlus := h.cfg.nginxRuntimeMgr.IsPlus() - - files := h.cfg.generator.Generate(conf) - if err := h.cfg.nginxFileMgr.ReplaceFiles(files); err != nil { - return fmt.Errorf("failed to replace NGINX configuration files: %w", err) +// updateUpstreamServers determines which servers have changed and uses the NGINX Plus API to update them. +// Only applicable when using NGINX Plus. +func (h *eventHandlerImpl) updateUpstreamServers(conf dataplane.Configuration) error { + if !h.cfg.plus { + return nil } - reload := func() error { - if err := h.cfg.nginxRuntimeMgr.Reload(ctx, conf.Version); err != nil { - return fmt.Errorf("failed to reload NGINX: %w", err) - } + prevUpstreams, prevStreamUpstreams, err := h.cfg.nginxRuntimeMgr.GetUpstreams() + if err != nil { + return fmt.Errorf("failed to get upstreams from API: %w", err) + } - return nil + type upstream struct { + name string + servers []ngxclient.UpstreamServer } + var upstreams []upstream - if isPlus { - type upstream struct { - name string - servers []ngxclient.UpstreamServer + for _, u := range conf.Upstreams { + confUpstream := upstream{ + name: u.Name, + servers: ngxConfig.ConvertEndpoints(u.Endpoints), } - var upstreams []upstream - prevUpstreams, err := h.cfg.nginxRuntimeMgr.GetUpstreams() - if err != nil { - logger.Error(err, "failed to get upstreams from API, reloading configuration instead") - return reload() + if u, ok := prevUpstreams[confUpstream.name]; ok { + if !serversEqual(confUpstream.servers, u.Peers) { + upstreams = append(upstreams, confUpstream) + } } + } - for _, u := range conf.Upstreams { - confUpstream := upstream{ - name: u.Name, - servers: ngxConfig.ConvertEndpoints(u.Endpoints), - } + type streamUpstream struct { + name string + servers []ngxclient.StreamUpstreamServer + } + var streamUpstreams []streamUpstream - if u, ok := prevUpstreams[confUpstream.name]; ok { - if !serversEqual(confUpstream.servers, u.Peers) { - upstreams = append(upstreams, confUpstream) - } - } + for _, u := range conf.StreamUpstreams { + confUpstream := streamUpstream{ + name: u.Name, + servers: ngxConfig.ConvertStreamEndpoints(u.Endpoints), } - var reloadPlus bool - for _, upstream := range upstreams { - if err := h.cfg.nginxRuntimeMgr.UpdateHTTPServers(upstream.name, upstream.servers); err != nil { - logger.Error( - err, "couldn't update upstream via the API, reloading configuration instead", - "upstreamName", upstream.name, - ) - reloadPlus = true + if u, ok := prevStreamUpstreams[confUpstream.name]; ok { + if !serversEqual(confUpstream.servers, u.Peers) { + streamUpstreams = append(streamUpstreams, confUpstream) } } + } - if !reloadPlus { - return nil + var updateErr error + for _, upstream := range upstreams { + if err := h.cfg.nginxRuntimeMgr.UpdateHTTPServers(upstream.name, upstream.servers); err != nil { + updateErr = errors.Join(updateErr, fmt.Errorf( + "couldn't update upstream %q via the API: %w", upstream.name, err)) } } - return reload() + for _, upstream := range streamUpstreams { + if err := h.cfg.nginxRuntimeMgr.UpdateStreamServers(upstream.name, upstream.servers); err != nil { + updateErr = errors.Join(updateErr, fmt.Errorf( + "couldn't update stream upstream %q via the API: %w", upstream.name, err)) + } + } + + return updateErr } -func serversEqual(newServers []ngxclient.UpstreamServer, oldServers []ngxclient.Peer) bool { +// serversEqual accepts lists of either UpstreamServer/Peer or StreamUpstreamServer/StreamPeer and determines +// if the server names within these lists are equal. +func serversEqual[ + upstreamServer ngxclient.UpstreamServer | ngxclient.StreamUpstreamServer, + peer ngxclient.Peer | ngxclient.StreamPeer, +](newServers []upstreamServer, oldServers []peer) bool { if len(newServers) != len(oldServers) { return false } + getServerVal := func(T any) string { + var server string + switch t := T.(type) { + case ngxclient.UpstreamServer: + server = t.Server + case ngxclient.StreamUpstreamServer: + server = t.Server + case ngxclient.Peer: + server = t.Server + case ngxclient.StreamPeer: + server = t.Server + } + return server + } + diff := make(map[string]struct{}, len(newServers)) for _, s := range newServers { - diff[s.Server] = struct{}{} + diff[getServerVal(s)] = struct{}{} } for _, s := range oldServers { - if _, ok := diff[s.Server]; !ok { + if _, ok := diff[getServerVal(s)]; !ok { return false } } diff --git a/internal/mode/static/handler_test.go b/internal/mode/static/handler_test.go index 67bf0e8e0e..c24f5d27d2 100644 --- a/internal/mode/static/handler_test.go +++ b/internal/mode/static/handler_test.go @@ -423,20 +423,29 @@ var _ = Describe("eventHandler", func() { }, }, } - fakeNginxRuntimeMgr.GetUpstreamsReturns(upstreams, nil) + + streamUpstreams := ngxclient.StreamUpstreams{ + "two": ngxclient.StreamUpstream{ + Peers: []ngxclient.StreamPeer{ + {Server: "server2"}, + }, + }, + } + + fakeNginxRuntimeMgr.GetUpstreamsReturns(upstreams, streamUpstreams, nil) }) When("running NGINX Plus", func() { It("should call the NGINX Plus API", func() { - fakeNginxRuntimeMgr.IsPlusReturns(true) + handler.cfg.plus = true handler.HandleEventBatch(context.Background(), ctlrZap.New(), batch) dcfg := dataplane.GetDefaultConfiguration(&graph.Graph{}, 1) Expect(helpers.Diff(handler.GetLatestConfiguration(), &dcfg)).To(BeEmpty()) - Expect(fakeGenerator.GenerateCallCount()).To(Equal(1)) - Expect(fakeNginxFileMgr.ReplaceFilesCallCount()).To(Equal(1)) + Expect(fakeGenerator.GenerateCallCount()).To(Equal(0)) + Expect(fakeNginxFileMgr.ReplaceFilesCallCount()).To(Equal(0)) Expect(fakeNginxRuntimeMgr.GetUpstreamsCallCount()).To(Equal(1)) }) }) @@ -463,19 +472,11 @@ var _ = Describe("eventHandler", func() { Name: "one", }, }, - } - - type callCounts struct { - generate int - update int - reload int - } - - assertCallCounts := func(cc callCounts) { - Expect(fakeGenerator.GenerateCallCount()).To(Equal(cc.generate)) - Expect(fakeNginxFileMgr.ReplaceFilesCallCount()).To(Equal(cc.generate)) - Expect(fakeNginxRuntimeMgr.UpdateHTTPServersCallCount()).To(Equal(cc.update)) - Expect(fakeNginxRuntimeMgr.ReloadCallCount()).To(Equal(cc.reload)) + StreamUpstreams: []dataplane.Upstream{ + { + Name: "two", + }, + }, } BeforeEach(func() { @@ -486,47 +487,49 @@ var _ = Describe("eventHandler", func() { }, }, } - fakeNginxRuntimeMgr.GetUpstreamsReturns(upstreams, nil) + + streamUpstreams := ngxclient.StreamUpstreams{ + "two": ngxclient.StreamUpstream{ + Peers: []ngxclient.StreamPeer{ + {Server: "server2"}, + }, + }, + } + + fakeNginxRuntimeMgr.GetUpstreamsReturns(upstreams, streamUpstreams, nil) }) When("running NGINX Plus", func() { BeforeEach(func() { - fakeNginxRuntimeMgr.IsPlusReturns(true) + handler.cfg.plus = true }) It("should update servers using the NGINX Plus API", func() { - Expect(handler.updateUpstreamServers(context.Background(), ctlrZap.New(), conf)).To(Succeed()) - - assertCallCounts(callCounts{generate: 1, update: 1, reload: 0}) + Expect(handler.updateUpstreamServers(conf)).To(Succeed()) + Expect(fakeNginxRuntimeMgr.UpdateHTTPServersCallCount()).To(Equal(1)) }) - It("should reload when GET API returns an error", func() { - fakeNginxRuntimeMgr.GetUpstreamsReturns(nil, errors.New("error")) - Expect(handler.updateUpstreamServers(context.Background(), ctlrZap.New(), conf)).To(Succeed()) - - assertCallCounts(callCounts{generate: 1, update: 0, reload: 1}) + It("should return error when GET API returns an error", func() { + fakeNginxRuntimeMgr.GetUpstreamsReturns(nil, nil, errors.New("error")) + Expect(handler.updateUpstreamServers(conf)).ToNot(Succeed()) }) - It("should reload when POST API returns an error", func() { + It("should return error when UpdateHTTPServers API returns an error", func() { fakeNginxRuntimeMgr.UpdateHTTPServersReturns(errors.New("error")) - Expect(handler.updateUpstreamServers(context.Background(), ctlrZap.New(), conf)).To(Succeed()) + Expect(handler.updateUpstreamServers(conf)).ToNot(Succeed()) + }) - assertCallCounts(callCounts{generate: 1, update: 1, reload: 1}) + It("should return error when UpdateStreamServers API returns an error", func() { + fakeNginxRuntimeMgr.UpdateStreamServersReturns(errors.New("error")) + Expect(handler.updateUpstreamServers(conf)).ToNot(Succeed()) }) }) When("not running NGINX Plus", func() { - It("should update servers by reloading", func() { - Expect(handler.updateUpstreamServers(context.Background(), ctlrZap.New(), conf)).To(Succeed()) - - assertCallCounts(callCounts{generate: 1, update: 0, reload: 1}) - }) + It("should not do anything", func() { + Expect(handler.updateUpstreamServers(conf)).To(Succeed()) - It("should return an error when reloading fails", func() { - fakeNginxRuntimeMgr.ReloadReturns(errors.New("error")) - Expect(handler.updateUpstreamServers(context.Background(), ctlrZap.New(), conf)).ToNot(Succeed()) - - assertCallCounts(callCounts{generate: 1, update: 0, reload: 1}) + Expect(fakeNginxRuntimeMgr.UpdateHTTPServersCallCount()).To(Equal(0)) }) }) }) @@ -612,7 +615,7 @@ var _ = Describe("eventHandler", func() { }) var _ = Describe("serversEqual", func() { - DescribeTable("determines if server lists are equal", + DescribeTable("determines if HTTP server lists are equal", func(newServers []ngxclient.UpstreamServer, oldServers []ngxclient.Peer, equal bool) { Expect(serversEqual(newServers, oldServers)).To(Equal(equal)) }, @@ -649,6 +652,43 @@ var _ = Describe("serversEqual", func() { true, ), ) + DescribeTable("determines if stream server lists are equal", + func(newServers []ngxclient.StreamUpstreamServer, oldServers []ngxclient.StreamPeer, equal bool) { + Expect(serversEqual(newServers, oldServers)).To(Equal(equal)) + }, + Entry("different length", + []ngxclient.StreamUpstreamServer{ + {Server: "server1"}, + }, + []ngxclient.StreamPeer{ + {Server: "server1"}, + {Server: "server2"}, + }, + false, + ), + Entry("differing elements", + []ngxclient.StreamUpstreamServer{ + {Server: "server1"}, + {Server: "server2"}, + }, + []ngxclient.StreamPeer{ + {Server: "server1"}, + {Server: "server3"}, + }, + false, + ), + Entry("same elements", + []ngxclient.StreamUpstreamServer{ + {Server: "server1"}, + {Server: "server2"}, + }, + []ngxclient.StreamPeer{ + {Server: "server1"}, + {Server: "server2"}, + }, + true, + ), + ) }) var _ = Describe("getGatewayAddresses", func() { diff --git a/internal/mode/static/manager.go b/internal/mode/static/manager.go index bc24210318..bc94e61346 100644 --- a/internal/mode/static/manager.go +++ b/internal/mode/static/manager.go @@ -172,15 +172,17 @@ func StartManager(cfg config.Config) error { ) var ngxPlusClient ngxruntime.NginxPlusClient + if cfg.Plus { + ngxPlusClient, err = ngxruntime.CreatePlusClient() + if err != nil { + return fmt.Errorf("error creating NGINX plus client: %w", err) + } + } if cfg.MetricsConfig.Enabled { constLabels := map[string]string{"class": cfg.GatewayClassName} var ngxCollector prometheus.Collector if cfg.Plus { - ngxPlusClient, err = ngxruntime.CreatePlusClient() - if err != nil { - return fmt.Errorf("error creating NGINX plus client: %w", err) - } ngxCollector, err = collectors.NewNginxPlusMetricsCollector(ngxPlusClient, constLabels, promLogger) } else { ngxCollector = collectors.NewNginxMetricsCollector(constLabels, promLogger) diff --git a/internal/mode/static/nginx/config/convert.go b/internal/mode/static/nginx/config/convert.go index ff20bf888d..3038149a0e 100644 --- a/internal/mode/static/nginx/config/convert.go +++ b/internal/mode/static/nginx/config/convert.go @@ -13,17 +13,26 @@ func ConvertEndpoints(eps []resolver.Endpoint) []ngxclient.UpstreamServer { servers := make([]ngxclient.UpstreamServer, 0, len(eps)) for _, ep := range eps { - var port string - if ep.Port != 0 { - port = fmt.Sprintf(":%d", ep.Port) - } + port, format := getPortAndIPFormat(ep) - format := "%s%s" - if ep.IPv6 { - format = "[%s]%s" + server := ngxclient.UpstreamServer{ + Server: fmt.Sprintf(format, ep.Address, port), } - server := ngxclient.UpstreamServer{ + servers = append(servers, server) + } + + return servers +} + +// ConvertStreamEndpoints converts a list of Endpoints into a list of NGINX Plus SDK StreamUpstreamServers. +func ConvertStreamEndpoints(eps []resolver.Endpoint) []ngxclient.StreamUpstreamServer { + servers := make([]ngxclient.StreamUpstreamServer, 0, len(eps)) + + for _, ep := range eps { + port, format := getPortAndIPFormat(ep) + + server := ngxclient.StreamUpstreamServer{ Server: fmt.Sprintf(format, ep.Address, port), } @@ -32,3 +41,18 @@ func ConvertEndpoints(eps []resolver.Endpoint) []ngxclient.UpstreamServer { return servers } + +func getPortAndIPFormat(ep resolver.Endpoint) (string, string) { + var port string + + if ep.Port != 0 { + port = fmt.Sprintf(":%d", ep.Port) + } + + format := "%s%s" + if ep.IPv6 { + format = "[%s]%s" + } + + return port, format +} diff --git a/internal/mode/static/nginx/config/convert_test.go b/internal/mode/static/nginx/config/convert_test.go index 6be41ccda6..68520dfd78 100644 --- a/internal/mode/static/nginx/config/convert_test.go +++ b/internal/mode/static/nginx/config/convert_test.go @@ -42,3 +42,37 @@ func TestConvertEndpoints(t *testing.T) { g := NewWithT(t) g.Expect(ConvertEndpoints(endpoints)).To(Equal(expUpstreams)) } + +func TestConvertStreamEndpoints(t *testing.T) { + t.Parallel() + endpoints := []resolver.Endpoint{ + { + Address: "1.2.3.4", + Port: 80, + }, + { + Address: "5.6.7.8", + Port: 0, + }, + { + Address: "2001:db8::1", + Port: 443, + IPv6: true, + }, + } + + expUpstreams := []ngxclient.StreamUpstreamServer{ + { + Server: "1.2.3.4:80", + }, + { + Server: "5.6.7.8", + }, + { + Server: "[2001:db8::1]:443", + }, + } + + g := NewWithT(t) + g.Expect(ConvertStreamEndpoints(endpoints)).To(Equal(expUpstreams)) +} diff --git a/internal/mode/static/nginx/config/http/config.go b/internal/mode/static/nginx/config/http/config.go index 24aecaa3e4..6d063dc8a7 100644 --- a/internal/mode/static/nginx/config/http/config.go +++ b/internal/mode/static/nginx/config/http/config.go @@ -82,9 +82,10 @@ const ( // Upstream holds all configuration for an HTTP upstream. type Upstream struct { - Name string - ZoneSize string // format: 512k, 1m - Servers []UpstreamServer + Name string + ZoneSize string // format: 512k, 1m + StateFile string + Servers []UpstreamServer } // UpstreamServer holds all configuration for an HTTP upstream server. diff --git a/internal/mode/static/nginx/config/stream/config.go b/internal/mode/static/nginx/config/stream/config.go index ddc215eea7..1202c1ec85 100644 --- a/internal/mode/static/nginx/config/stream/config.go +++ b/internal/mode/static/nginx/config/stream/config.go @@ -15,9 +15,10 @@ type Server struct { // Upstream holds all configuration for a stream upstream. type Upstream struct { - Name string - ZoneSize string // format: 512k, 1m - Servers []UpstreamServer + Name string + ZoneSize string // format: 512k, 1m + StateFile string + Servers []UpstreamServer } // UpstreamServer holds all configuration for a stream upstream server. diff --git a/internal/mode/static/nginx/config/upstreams.go b/internal/mode/static/nginx/config/upstreams.go index 88c66c47fd..51af6f4f4b 100644 --- a/internal/mode/static/nginx/config/upstreams.go +++ b/internal/mode/static/nginx/config/upstreams.go @@ -27,6 +27,8 @@ const ( ossZoneSizeStream = "512k" // plusZoneSize is the upstream zone size for nginx plus. plusZoneSizeStream = "1m" + // stateDir is the directory for storing state files. + stateDir = "/var/lib/nginx/state" ) func (g GeneratorImpl) executeUpstreams(conf dataplane.Configuration) []executeResult { @@ -64,9 +66,11 @@ func (g GeneratorImpl) createStreamUpstreams(upstreams []dataplane.Upstream) []s } func (g GeneratorImpl) createStreamUpstream(up dataplane.Upstream) stream.Upstream { + var stateFile string zoneSize := ossZoneSizeStream if g.plus { zoneSize = plusZoneSizeStream + stateFile = fmt.Sprintf("%s/%s.conf", stateDir, up.Name) } upstreamServers := make([]stream.UpstreamServer, len(up.Endpoints)) @@ -81,9 +85,10 @@ func (g GeneratorImpl) createStreamUpstream(up dataplane.Upstream) stream.Upstre } return stream.Upstream{ - Name: up.Name, - ZoneSize: zoneSize, - Servers: upstreamServers, + Name: up.Name, + ZoneSize: zoneSize, + StateFile: stateFile, + Servers: upstreamServers, } } @@ -101,15 +106,18 @@ func (g GeneratorImpl) createUpstreams(upstreams []dataplane.Upstream) []http.Up } func (g GeneratorImpl) createUpstream(up dataplane.Upstream) http.Upstream { + var stateFile string zoneSize := ossZoneSize if g.plus { zoneSize = plusZoneSize + stateFile = fmt.Sprintf("%s/%s.conf", stateDir, up.Name) } if len(up.Endpoints) == 0 { return http.Upstream{ - Name: up.Name, - ZoneSize: zoneSize, + Name: up.Name, + ZoneSize: zoneSize, + StateFile: stateFile, Servers: []http.UpstreamServer{ { Address: nginx503Server, @@ -130,9 +138,10 @@ func (g GeneratorImpl) createUpstream(up dataplane.Upstream) http.Upstream { } return http.Upstream{ - Name: up.Name, - ZoneSize: zoneSize, - Servers: upstreamServers, + Name: up.Name, + ZoneSize: zoneSize, + StateFile: stateFile, + Servers: upstreamServers, } } diff --git a/internal/mode/static/nginx/config/upstreams_template.go b/internal/mode/static/nginx/config/upstreams_template.go index a04915bec8..40d5740ad0 100644 --- a/internal/mode/static/nginx/config/upstreams_template.go +++ b/internal/mode/static/nginx/config/upstreams_template.go @@ -12,8 +12,13 @@ upstream {{ $u.Name }} { {{ if $u.ZoneSize -}} zone {{ $u.Name }} {{ $u.ZoneSize }}; {{ end -}} - {{ range $server := $u.Servers }} + + {{- if $u.StateFile }} + state {{ $u.StateFile }}; + {{- else }} + {{ range $server := $u.Servers }} server {{ $server.Address }}; + {{- end }} {{- end }} } {{ end -}} diff --git a/internal/mode/static/nginx/config/upstreams_test.go b/internal/mode/static/nginx/config/upstreams_test.go index 5b3a8268a3..f2e5b1071b 100644 --- a/internal/mode/static/nginx/config/upstreams_test.go +++ b/internal/mode/static/nginx/config/upstreams_test.go @@ -289,29 +289,60 @@ func TestCreateUpstreamPlus(t *testing.T) { t.Parallel() gen := GeneratorImpl{plus: true} - stateUpstream := dataplane.Upstream{ - Name: "multiple-endpoints", - Endpoints: []resolver.Endpoint{ - { - Address: "10.0.0.1", - Port: 80, + tests := []struct { + msg string + stateUpstream dataplane.Upstream + expectedUpstream http.Upstream + }{ + { + msg: "with endpoints", + stateUpstream: dataplane.Upstream{ + Name: "endpoints", + Endpoints: []resolver.Endpoint{ + { + Address: "10.0.0.1", + Port: 80, + }, + }, + }, + expectedUpstream: http.Upstream{ + Name: "endpoints", + ZoneSize: plusZoneSize, + StateFile: stateDir + "/endpoints.conf", + Servers: []http.UpstreamServer{ + { + Address: "10.0.0.1:80", + }, + }, }, }, - } - expectedUpstream := http.Upstream{ - Name: "multiple-endpoints", - ZoneSize: plusZoneSize, - Servers: []http.UpstreamServer{ - { - Address: "10.0.0.1:80", + { + msg: "no endpoints", + stateUpstream: dataplane.Upstream{ + Name: "no-endpoints", + Endpoints: []resolver.Endpoint{}, + }, + expectedUpstream: http.Upstream{ + Name: "no-endpoints", + ZoneSize: plusZoneSize, + StateFile: stateDir + "/no-endpoints.conf", + Servers: []http.UpstreamServer{ + { + Address: nginx503Server, + }, + }, }, }, } - result := gen.createUpstream(stateUpstream) - - g := NewWithT(t) - g.Expect(result).To(Equal(expectedUpstream)) + for _, test := range tests { + t.Run(test.msg, func(t *testing.T) { + t.Parallel() + g := NewWithT(t) + result := gen.createUpstream(test.stateUpstream) + g.Expect(result).To(Equal(test.expectedUpstream)) + }) + } } func TestExecuteStreamUpstreams(t *testing.T) { @@ -491,8 +522,9 @@ func TestCreateStreamUpstreamPlus(t *testing.T) { }, } expectedUpstream := stream.Upstream{ - Name: "multiple-endpoints", - ZoneSize: plusZoneSize, + Name: "multiple-endpoints", + ZoneSize: plusZoneSize, + StateFile: stateDir + "/multiple-endpoints.conf", Servers: []stream.UpstreamServer{ { Address: "10.0.0.1:80", diff --git a/internal/mode/static/nginx/runtime/manager.go b/internal/mode/static/nginx/runtime/manager.go index 45d24dbb75..afa641645f 100644 --- a/internal/mode/static/nginx/runtime/manager.go +++ b/internal/mode/static/nginx/runtime/manager.go @@ -47,6 +47,16 @@ type NginxPlusClient interface { err error, ) GetUpstreams() (*ngxclient.Upstreams, error) + UpdateStreamServers( + upstream string, + servers []ngxclient.StreamUpstreamServer, + ) ( + added []ngxclient.StreamUpstreamServer, + deleted []ngxclient.StreamUpstreamServer, + updated []ngxclient.StreamUpstreamServer, + err error, + ) + GetStreamUpstreams() (*ngxclient.StreamUpstreams, error) } //counterfeiter:generate . Manager @@ -57,12 +67,15 @@ type Manager interface { Reload(ctx context.Context, configVersion int) error // IsPlus returns whether or not we are running NGINX plus. IsPlus() bool - // UpdateHTTPServers uses the NGINX Plus API to update HTTP servers. + // GetUpstreams uses the NGINX Plus API to get the upstreams. + // Only usable if running NGINX Plus. + GetUpstreams() (ngxclient.Upstreams, ngxclient.StreamUpstreams, error) + // UpdateHTTPServers uses the NGINX Plus API to update HTTP upstream servers. // Only usable if running NGINX Plus. UpdateHTTPServers(string, []ngxclient.UpstreamServer) error - // GetUpstreams uses the NGINX Plus API to get the upstreams. + // UpdateStreamServers uses the NGINX Plus API to update stream upstream servers. // Only usable if running NGINX Plus. - GetUpstreams() (ngxclient.Upstreams, error) + UpdateStreamServers(string, []ngxclient.StreamUpstreamServer) error } // MetricsCollector is an interface for the metrics of the NGINX runtime manager. @@ -143,6 +156,34 @@ func (m *ManagerImpl) Reload(ctx context.Context, configVersion int) error { return nil } +// GetUpstreams uses the NGINX Plus API to get the upstreams. +// Only usable if running NGINX Plus. +func (m *ManagerImpl) GetUpstreams() (ngxclient.Upstreams, ngxclient.StreamUpstreams, error) { + if !m.IsPlus() { + panic("cannot get upstream servers: NGINX Plus not enabled") + } + + upstreams, err := m.ngxPlusClient.GetUpstreams() + if err != nil { + return nil, nil, err + } + + if upstreams == nil { + return nil, nil, errors.New("GET upstreams returned nil value") + } + + streamUpstreams, err := m.ngxPlusClient.GetStreamUpstreams() + if err != nil { + return nil, nil, err + } + + if streamUpstreams == nil { + return nil, nil, errors.New("GET stream upstreams returned nil value") + } + + return *upstreams, *streamUpstreams, nil +} + // UpdateHTTPServers uses the NGINX Plus API to update HTTP upstream servers. // Only usable if running NGINX Plus. func (m *ManagerImpl) UpdateHTTPServers(upstream string, servers []ngxclient.UpstreamServer) error { @@ -158,23 +199,19 @@ func (m *ManagerImpl) UpdateHTTPServers(upstream string, servers []ngxclient.Ups return err } -// GetUpstreams uses the NGINX Plus API to get the upstreams. +// UpdateStreamServers uses the NGINX Plus API to update stream upstream servers. // Only usable if running NGINX Plus. -func (m *ManagerImpl) GetUpstreams() (ngxclient.Upstreams, error) { +func (m *ManagerImpl) UpdateStreamServers(upstream string, servers []ngxclient.StreamUpstreamServer) error { if !m.IsPlus() { - panic("cannot get HTTP upstream servers: NGINX Plus not enabled") - } - - upstreams, err := m.ngxPlusClient.GetUpstreams() - if err != nil { - return nil, err + panic("cannot update stream upstream servers: NGINX Plus not enabled") } - if upstreams == nil { - return nil, errors.New("GET upstreams returned nil value") - } + added, deleted, updated, err := m.ngxPlusClient.UpdateStreamServers(upstream, servers) + m.logger.V(1).Info("Added stream upstream servers", "count", len(added)) + m.logger.V(1).Info("Deleted stream upstream servers", "count", len(deleted)) + m.logger.V(1).Info("Updated stream upstream servers", "count", len(updated)) - return *upstreams, nil + return err } //counterfeiter:generate . ProcessHandler diff --git a/internal/mode/static/nginx/runtime/manager_test.go b/internal/mode/static/nginx/runtime/manager_test.go index 15eb498a7f..036731e1ea 100644 --- a/internal/mode/static/nginx/runtime/manager_test.go +++ b/internal/mode/static/nginx/runtime/manager_test.go @@ -27,11 +27,12 @@ var _ = Describe("NGINX Runtime Manager", func() { }) var ( - err error - manager runtime.Manager - upstreamServers []ngxclient.UpstreamServer - ngxPlusClient *runtimefakes.FakeNginxPlusClient - process *runtimefakes.FakeProcessHandler + err error + manager runtime.Manager + upstreamServers []ngxclient.UpstreamServer + streamUpstreamServers []ngxclient.StreamUpstreamServer + ngxPlusClient *runtimefakes.FakeNginxPlusClient + process *runtimefakes.FakeProcessHandler metrics *runtimefakes.FakeMetricsCollector verifyClient *runtimefakes.FakeVerifyClient @@ -41,6 +42,9 @@ var _ = Describe("NGINX Runtime Manager", func() { upstreamServers = []ngxclient.UpstreamServer{ {}, } + streamUpstreamServers = []ngxclient.StreamUpstreamServer{ + {}, + } }) Context("Reload", func() { @@ -150,11 +154,16 @@ var _ = Describe("NGINX Runtime Manager", func() { Expect(manager.UpdateHTTPServers("test", upstreamServers)).To(Succeed()) }) + It("successfully updates stream server upstream", func() { + Expect(manager.UpdateStreamServers("test", streamUpstreamServers)).To(Succeed()) + }) + It("returns no upstreams from NGINX Plus API when upstreams are nil", func() { - upstreams, err := manager.GetUpstreams() + upstreams, streamUpstreams, err := manager.GetUpstreams() Expect(err).To(HaveOccurred()) Expect(upstreams).To(BeEmpty()) + Expect(streamUpstreams).To(BeEmpty()) }) It("successfully returns server upstreams", func() { @@ -177,22 +186,77 @@ var _ = Describe("NGINX Runtime Manager", func() { }, } + expStreamUpstreams := ngxclient.StreamUpstreams{ + "upstream1": { + Zone: "zone1", + Peers: []ngxclient.StreamPeer{ + {ID: 1, Name: "peer1-name"}, + }, + Zombies: 2, + }, + "upstream2": { + Zone: "zone2", + Peers: []ngxclient.StreamPeer{ + {ID: 2, Name: "peer2-name"}, + }, + Zombies: 1, + }, + } + ngxPlusClient.GetUpstreamsReturns(&expUpstreams, nil) + ngxPlusClient.GetStreamUpstreamsReturns(&expStreamUpstreams, nil) - upstreams, err := manager.GetUpstreams() + upstreams, streamUpstreams, err := manager.GetUpstreams() Expect(err).NotTo(HaveOccurred()) Expect(expUpstreams).To(Equal(upstreams)) + Expect(expStreamUpstreams).To(Equal(streamUpstreams)) }) It("returns an error when GetUpstreams fails", func() { ngxPlusClient.GetUpstreamsReturns(nil, errors.New("failed to get upstreams")) - upstreams, err := manager.GetUpstreams() + upstreams, streamUpstreams, err := manager.GetUpstreams() + + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError("failed to get upstreams")) + Expect(upstreams).To(BeNil()) + Expect(streamUpstreams).To(BeNil()) + }) + + It("returns an error when GetUpstreams returns nil", func() { + ngxPlusClient.GetUpstreamsReturns(nil, nil) + + upstreams, streamUpstreams, err := manager.GetUpstreams() + + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError("GET upstreams returned nil value")) + Expect(upstreams).To(BeNil()) + Expect(streamUpstreams).To(BeNil()) + }) + + It("returns an error when GetStreamUpstreams fails", func() { + ngxPlusClient.GetUpstreamsReturns(&ngxclient.Upstreams{}, nil) + ngxPlusClient.GetStreamUpstreamsReturns(nil, errors.New("failed to get upstreams")) + + upstreams, streamUpstreams, err := manager.GetUpstreams() Expect(err).To(HaveOccurred()) Expect(err).To(MatchError("failed to get upstreams")) Expect(upstreams).To(BeNil()) + Expect(streamUpstreams).To(BeNil()) + }) + + It("returns an error when GetStreamUpstreams returns nil", func() { + ngxPlusClient.GetUpstreamsReturns(&ngxclient.Upstreams{}, nil) + ngxPlusClient.GetStreamUpstreamsReturns(nil, nil) + + upstreams, streamUpstreams, err := manager.GetUpstreams() + + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError("GET stream upstreams returned nil value")) + Expect(upstreams).To(BeNil()) + Expect(streamUpstreams).To(BeNil()) }) }) @@ -202,6 +266,15 @@ var _ = Describe("NGINX Runtime Manager", func() { manager = runtime.NewManagerImpl(ngxPlusClient, nil, zap.New(), nil, nil) }) + It("should panic when fetching upstream servers", func() { + upstreams := func() { + _, _, err = manager.GetUpstreams() + } + + Expect(upstreams).To(Panic()) + Expect(err).ToNot(HaveOccurred()) + }) + It("should panic when updating HTTP upstream servers", func() { updateServers := func() { err = manager.UpdateHTTPServers("test", upstreamServers) @@ -211,12 +284,12 @@ var _ = Describe("NGINX Runtime Manager", func() { Expect(err).ToNot(HaveOccurred()) }) - It("should panic when fetching HTTP upstream servers", func() { - upstreams := func() { - _, err = manager.GetUpstreams() + It("should panic when updating stream upstream servers", func() { + updateServers := func() { + err = manager.UpdateStreamServers("test", streamUpstreamServers) } - Expect(upstreams).To(Panic()) + Expect(updateServers).To(Panic()) Expect(err).ToNot(HaveOccurred()) }) }) diff --git a/internal/mode/static/nginx/runtime/runtimefakes/fake_manager.go b/internal/mode/static/nginx/runtime/runtimefakes/fake_manager.go index 2538e32de3..ea7504a762 100644 --- a/internal/mode/static/nginx/runtime/runtimefakes/fake_manager.go +++ b/internal/mode/static/nginx/runtime/runtimefakes/fake_manager.go @@ -10,17 +10,19 @@ import ( ) type FakeManager struct { - GetUpstreamsStub func() (client.Upstreams, error) + GetUpstreamsStub func() (client.Upstreams, client.StreamUpstreams, error) getUpstreamsMutex sync.RWMutex getUpstreamsArgsForCall []struct { } getUpstreamsReturns struct { result1 client.Upstreams - result2 error + result2 client.StreamUpstreams + result3 error } getUpstreamsReturnsOnCall map[int]struct { result1 client.Upstreams - result2 error + result2 client.StreamUpstreams + result3 error } IsPlusStub func() bool isPlusMutex sync.RWMutex @@ -56,11 +58,23 @@ type FakeManager struct { updateHTTPServersReturnsOnCall map[int]struct { result1 error } + UpdateStreamServersStub func(string, []client.StreamUpstreamServer) error + updateStreamServersMutex sync.RWMutex + updateStreamServersArgsForCall []struct { + arg1 string + arg2 []client.StreamUpstreamServer + } + updateStreamServersReturns struct { + result1 error + } + updateStreamServersReturnsOnCall map[int]struct { + result1 error + } invocations map[string][][]interface{} invocationsMutex sync.RWMutex } -func (fake *FakeManager) GetUpstreams() (client.Upstreams, error) { +func (fake *FakeManager) GetUpstreams() (client.Upstreams, client.StreamUpstreams, error) { fake.getUpstreamsMutex.Lock() ret, specificReturn := fake.getUpstreamsReturnsOnCall[len(fake.getUpstreamsArgsForCall)] fake.getUpstreamsArgsForCall = append(fake.getUpstreamsArgsForCall, struct { @@ -73,9 +87,9 @@ func (fake *FakeManager) GetUpstreams() (client.Upstreams, error) { return stub() } if specificReturn { - return ret.result1, ret.result2 + return ret.result1, ret.result2, ret.result3 } - return fakeReturns.result1, fakeReturns.result2 + return fakeReturns.result1, fakeReturns.result2, fakeReturns.result3 } func (fake *FakeManager) GetUpstreamsCallCount() int { @@ -84,36 +98,39 @@ func (fake *FakeManager) GetUpstreamsCallCount() int { return len(fake.getUpstreamsArgsForCall) } -func (fake *FakeManager) GetUpstreamsCalls(stub func() (client.Upstreams, error)) { +func (fake *FakeManager) GetUpstreamsCalls(stub func() (client.Upstreams, client.StreamUpstreams, error)) { fake.getUpstreamsMutex.Lock() defer fake.getUpstreamsMutex.Unlock() fake.GetUpstreamsStub = stub } -func (fake *FakeManager) GetUpstreamsReturns(result1 client.Upstreams, result2 error) { +func (fake *FakeManager) GetUpstreamsReturns(result1 client.Upstreams, result2 client.StreamUpstreams, result3 error) { fake.getUpstreamsMutex.Lock() defer fake.getUpstreamsMutex.Unlock() fake.GetUpstreamsStub = nil fake.getUpstreamsReturns = struct { result1 client.Upstreams - result2 error - }{result1, result2} + result2 client.StreamUpstreams + result3 error + }{result1, result2, result3} } -func (fake *FakeManager) GetUpstreamsReturnsOnCall(i int, result1 client.Upstreams, result2 error) { +func (fake *FakeManager) GetUpstreamsReturnsOnCall(i int, result1 client.Upstreams, result2 client.StreamUpstreams, result3 error) { fake.getUpstreamsMutex.Lock() defer fake.getUpstreamsMutex.Unlock() fake.GetUpstreamsStub = nil if fake.getUpstreamsReturnsOnCall == nil { fake.getUpstreamsReturnsOnCall = make(map[int]struct { result1 client.Upstreams - result2 error + result2 client.StreamUpstreams + result3 error }) } fake.getUpstreamsReturnsOnCall[i] = struct { result1 client.Upstreams - result2 error - }{result1, result2} + result2 client.StreamUpstreams + result3 error + }{result1, result2, result3} } func (fake *FakeManager) IsPlus() bool { @@ -298,6 +315,73 @@ func (fake *FakeManager) UpdateHTTPServersReturnsOnCall(i int, result1 error) { }{result1} } +func (fake *FakeManager) UpdateStreamServers(arg1 string, arg2 []client.StreamUpstreamServer) error { + var arg2Copy []client.StreamUpstreamServer + if arg2 != nil { + arg2Copy = make([]client.StreamUpstreamServer, len(arg2)) + copy(arg2Copy, arg2) + } + fake.updateStreamServersMutex.Lock() + ret, specificReturn := fake.updateStreamServersReturnsOnCall[len(fake.updateStreamServersArgsForCall)] + fake.updateStreamServersArgsForCall = append(fake.updateStreamServersArgsForCall, struct { + arg1 string + arg2 []client.StreamUpstreamServer + }{arg1, arg2Copy}) + stub := fake.UpdateStreamServersStub + fakeReturns := fake.updateStreamServersReturns + fake.recordInvocation("UpdateStreamServers", []interface{}{arg1, arg2Copy}) + fake.updateStreamServersMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeManager) UpdateStreamServersCallCount() int { + fake.updateStreamServersMutex.RLock() + defer fake.updateStreamServersMutex.RUnlock() + return len(fake.updateStreamServersArgsForCall) +} + +func (fake *FakeManager) UpdateStreamServersCalls(stub func(string, []client.StreamUpstreamServer) error) { + fake.updateStreamServersMutex.Lock() + defer fake.updateStreamServersMutex.Unlock() + fake.UpdateStreamServersStub = stub +} + +func (fake *FakeManager) UpdateStreamServersArgsForCall(i int) (string, []client.StreamUpstreamServer) { + fake.updateStreamServersMutex.RLock() + defer fake.updateStreamServersMutex.RUnlock() + argsForCall := fake.updateStreamServersArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeManager) UpdateStreamServersReturns(result1 error) { + fake.updateStreamServersMutex.Lock() + defer fake.updateStreamServersMutex.Unlock() + fake.UpdateStreamServersStub = nil + fake.updateStreamServersReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeManager) UpdateStreamServersReturnsOnCall(i int, result1 error) { + fake.updateStreamServersMutex.Lock() + defer fake.updateStreamServersMutex.Unlock() + fake.UpdateStreamServersStub = nil + if fake.updateStreamServersReturnsOnCall == nil { + fake.updateStreamServersReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.updateStreamServersReturnsOnCall[i] = struct { + result1 error + }{result1} +} + func (fake *FakeManager) Invocations() map[string][][]interface{} { fake.invocationsMutex.RLock() defer fake.invocationsMutex.RUnlock() @@ -309,6 +393,8 @@ func (fake *FakeManager) Invocations() map[string][][]interface{} { defer fake.reloadMutex.RUnlock() fake.updateHTTPServersMutex.RLock() defer fake.updateHTTPServersMutex.RUnlock() + fake.updateStreamServersMutex.RLock() + defer fake.updateStreamServersMutex.RUnlock() copiedInvocations := map[string][][]interface{}{} for key, value := range fake.invocations { copiedInvocations[key] = value diff --git a/internal/mode/static/nginx/runtime/runtimefakes/fake_nginx_plus_client.go b/internal/mode/static/nginx/runtime/runtimefakes/fake_nginx_plus_client.go index 3ea431d29b..8001f7f8a7 100644 --- a/internal/mode/static/nginx/runtime/runtimefakes/fake_nginx_plus_client.go +++ b/internal/mode/static/nginx/runtime/runtimefakes/fake_nginx_plus_client.go @@ -9,6 +9,18 @@ import ( ) type FakeNginxPlusClient struct { + GetStreamUpstreamsStub func() (*client.StreamUpstreams, error) + getStreamUpstreamsMutex sync.RWMutex + getStreamUpstreamsArgsForCall []struct { + } + getStreamUpstreamsReturns struct { + result1 *client.StreamUpstreams + result2 error + } + getStreamUpstreamsReturnsOnCall map[int]struct { + result1 *client.StreamUpstreams + result2 error + } GetUpstreamsStub func() (*client.Upstreams, error) getUpstreamsMutex sync.RWMutex getUpstreamsArgsForCall []struct { @@ -39,10 +51,84 @@ type FakeNginxPlusClient struct { result3 []client.UpstreamServer result4 error } + UpdateStreamServersStub func(string, []client.StreamUpstreamServer) ([]client.StreamUpstreamServer, []client.StreamUpstreamServer, []client.StreamUpstreamServer, error) + updateStreamServersMutex sync.RWMutex + updateStreamServersArgsForCall []struct { + arg1 string + arg2 []client.StreamUpstreamServer + } + updateStreamServersReturns struct { + result1 []client.StreamUpstreamServer + result2 []client.StreamUpstreamServer + result3 []client.StreamUpstreamServer + result4 error + } + updateStreamServersReturnsOnCall map[int]struct { + result1 []client.StreamUpstreamServer + result2 []client.StreamUpstreamServer + result3 []client.StreamUpstreamServer + result4 error + } invocations map[string][][]interface{} invocationsMutex sync.RWMutex } +func (fake *FakeNginxPlusClient) GetStreamUpstreams() (*client.StreamUpstreams, error) { + fake.getStreamUpstreamsMutex.Lock() + ret, specificReturn := fake.getStreamUpstreamsReturnsOnCall[len(fake.getStreamUpstreamsArgsForCall)] + fake.getStreamUpstreamsArgsForCall = append(fake.getStreamUpstreamsArgsForCall, struct { + }{}) + stub := fake.GetStreamUpstreamsStub + fakeReturns := fake.getStreamUpstreamsReturns + fake.recordInvocation("GetStreamUpstreams", []interface{}{}) + fake.getStreamUpstreamsMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeNginxPlusClient) GetStreamUpstreamsCallCount() int { + fake.getStreamUpstreamsMutex.RLock() + defer fake.getStreamUpstreamsMutex.RUnlock() + return len(fake.getStreamUpstreamsArgsForCall) +} + +func (fake *FakeNginxPlusClient) GetStreamUpstreamsCalls(stub func() (*client.StreamUpstreams, error)) { + fake.getStreamUpstreamsMutex.Lock() + defer fake.getStreamUpstreamsMutex.Unlock() + fake.GetStreamUpstreamsStub = stub +} + +func (fake *FakeNginxPlusClient) GetStreamUpstreamsReturns(result1 *client.StreamUpstreams, result2 error) { + fake.getStreamUpstreamsMutex.Lock() + defer fake.getStreamUpstreamsMutex.Unlock() + fake.GetStreamUpstreamsStub = nil + fake.getStreamUpstreamsReturns = struct { + result1 *client.StreamUpstreams + result2 error + }{result1, result2} +} + +func (fake *FakeNginxPlusClient) GetStreamUpstreamsReturnsOnCall(i int, result1 *client.StreamUpstreams, result2 error) { + fake.getStreamUpstreamsMutex.Lock() + defer fake.getStreamUpstreamsMutex.Unlock() + fake.GetStreamUpstreamsStub = nil + if fake.getStreamUpstreamsReturnsOnCall == nil { + fake.getStreamUpstreamsReturnsOnCall = make(map[int]struct { + result1 *client.StreamUpstreams + result2 error + }) + } + fake.getStreamUpstreamsReturnsOnCall[i] = struct { + result1 *client.StreamUpstreams + result2 error + }{result1, result2} +} + func (fake *FakeNginxPlusClient) GetUpstreams() (*client.Upstreams, error) { fake.getUpstreamsMutex.Lock() ret, specificReturn := fake.getUpstreamsReturnsOnCall[len(fake.getUpstreamsArgsForCall)] @@ -175,13 +261,93 @@ func (fake *FakeNginxPlusClient) UpdateHTTPServersReturnsOnCall(i int, result1 [ }{result1, result2, result3, result4} } +func (fake *FakeNginxPlusClient) UpdateStreamServers(arg1 string, arg2 []client.StreamUpstreamServer) ([]client.StreamUpstreamServer, []client.StreamUpstreamServer, []client.StreamUpstreamServer, error) { + var arg2Copy []client.StreamUpstreamServer + if arg2 != nil { + arg2Copy = make([]client.StreamUpstreamServer, len(arg2)) + copy(arg2Copy, arg2) + } + fake.updateStreamServersMutex.Lock() + ret, specificReturn := fake.updateStreamServersReturnsOnCall[len(fake.updateStreamServersArgsForCall)] + fake.updateStreamServersArgsForCall = append(fake.updateStreamServersArgsForCall, struct { + arg1 string + arg2 []client.StreamUpstreamServer + }{arg1, arg2Copy}) + stub := fake.UpdateStreamServersStub + fakeReturns := fake.updateStreamServersReturns + fake.recordInvocation("UpdateStreamServers", []interface{}{arg1, arg2Copy}) + fake.updateStreamServersMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2, ret.result3, ret.result4 + } + return fakeReturns.result1, fakeReturns.result2, fakeReturns.result3, fakeReturns.result4 +} + +func (fake *FakeNginxPlusClient) UpdateStreamServersCallCount() int { + fake.updateStreamServersMutex.RLock() + defer fake.updateStreamServersMutex.RUnlock() + return len(fake.updateStreamServersArgsForCall) +} + +func (fake *FakeNginxPlusClient) UpdateStreamServersCalls(stub func(string, []client.StreamUpstreamServer) ([]client.StreamUpstreamServer, []client.StreamUpstreamServer, []client.StreamUpstreamServer, error)) { + fake.updateStreamServersMutex.Lock() + defer fake.updateStreamServersMutex.Unlock() + fake.UpdateStreamServersStub = stub +} + +func (fake *FakeNginxPlusClient) UpdateStreamServersArgsForCall(i int) (string, []client.StreamUpstreamServer) { + fake.updateStreamServersMutex.RLock() + defer fake.updateStreamServersMutex.RUnlock() + argsForCall := fake.updateStreamServersArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeNginxPlusClient) UpdateStreamServersReturns(result1 []client.StreamUpstreamServer, result2 []client.StreamUpstreamServer, result3 []client.StreamUpstreamServer, result4 error) { + fake.updateStreamServersMutex.Lock() + defer fake.updateStreamServersMutex.Unlock() + fake.UpdateStreamServersStub = nil + fake.updateStreamServersReturns = struct { + result1 []client.StreamUpstreamServer + result2 []client.StreamUpstreamServer + result3 []client.StreamUpstreamServer + result4 error + }{result1, result2, result3, result4} +} + +func (fake *FakeNginxPlusClient) UpdateStreamServersReturnsOnCall(i int, result1 []client.StreamUpstreamServer, result2 []client.StreamUpstreamServer, result3 []client.StreamUpstreamServer, result4 error) { + fake.updateStreamServersMutex.Lock() + defer fake.updateStreamServersMutex.Unlock() + fake.UpdateStreamServersStub = nil + if fake.updateStreamServersReturnsOnCall == nil { + fake.updateStreamServersReturnsOnCall = make(map[int]struct { + result1 []client.StreamUpstreamServer + result2 []client.StreamUpstreamServer + result3 []client.StreamUpstreamServer + result4 error + }) + } + fake.updateStreamServersReturnsOnCall[i] = struct { + result1 []client.StreamUpstreamServer + result2 []client.StreamUpstreamServer + result3 []client.StreamUpstreamServer + result4 error + }{result1, result2, result3, result4} +} + func (fake *FakeNginxPlusClient) Invocations() map[string][][]interface{} { fake.invocationsMutex.RLock() defer fake.invocationsMutex.RUnlock() + fake.getStreamUpstreamsMutex.RLock() + defer fake.getStreamUpstreamsMutex.RUnlock() fake.getUpstreamsMutex.RLock() defer fake.getUpstreamsMutex.RUnlock() fake.updateHTTPServersMutex.RLock() defer fake.updateHTTPServersMutex.RUnlock() + fake.updateStreamServersMutex.RLock() + defer fake.updateStreamServersMutex.RUnlock() copiedInvocations := map[string][][]interface{}{} for key, value := range fake.invocations { copiedInvocations[key] = value