Skip to content

Commit

Permalink
Use state file for updating N+ upstreams (#2897)
Browse files Browse the repository at this point in the history
Problem: When splitting the data and control plane, we want to be able to maintain the ability to dynamically update upstreams without a reload if using NGINX Plus. With the current process, we write the upstreams into the nginx conf but don't reload, then call the API to actually do the update. Writing into the conf will trigger a reload from the agent once we split, however.

There were also two bugs:
- if metrics were disabled, the NGINX Plus client was not initialized, which prevented API calls from occurring and caused a reload instead
- stream upstreams were not updated with the API

Solution: Don't write the upstream servers into the nginx conf anymore when using NGINX Plus. Instead, utilize the `state` file option that the NGINX Plus API will populate with the upstream servers. This way we can just call the API and don't unintentionally reload by writing servers into the conf.

Also added support for updating stream upstreams, and fixed the initialization bug.
  • Loading branch information
sjberman authored Dec 13, 2024
1 parent 0cbb726 commit 8e2e2d8
Show file tree
Hide file tree
Showing 14 changed files with 725 additions and 189 deletions.
148 changes: 87 additions & 61 deletions internal/mode/static/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package static

import (
"context"
"errors"
"fmt"
"sync"
"time"
Expand Down Expand Up @@ -182,11 +183,11 @@ func (h *eventHandlerImpl) HandleEventBatch(ctx context.Context, logger logr.Log

h.setLatestConfiguration(&cfg)

err = h.updateUpstreamServers(
ctx,
logger,
cfg,
)
if h.cfg.plus {
err = h.updateUpstreamServers(cfg)
} else {
err = h.updateNginxConf(ctx, cfg)
}
case state.ClusterStateChange:
h.version++
cfg := dataplane.BuildConfiguration(ctx, gr, h.cfg.serviceResolver, h.version)
Expand All @@ -198,10 +199,7 @@ func (h *eventHandlerImpl) HandleEventBatch(ctx context.Context, logger logr.Log

h.setLatestConfiguration(&cfg)

err = h.updateNginxConf(
ctx,
cfg,
)
err = h.updateNginxConf(ctx, cfg)
}

var nginxReloadRes status.NginxReloadResult
Expand Down Expand Up @@ -306,7 +304,10 @@ func (h *eventHandlerImpl) parseAndCaptureEvent(ctx context.Context, logger logr
}

// updateNginxConf updates nginx conf files and reloads nginx.
func (h *eventHandlerImpl) updateNginxConf(ctx context.Context, conf dataplane.Configuration) error {
func (h *eventHandlerImpl) updateNginxConf(
ctx context.Context,
conf dataplane.Configuration,
) error {
files := h.cfg.generator.Generate(conf)
if err := h.cfg.nginxFileMgr.ReplaceFiles(files); err != nil {
return fmt.Errorf("failed to replace NGINX configuration files: %w", err)
Expand All @@ -316,89 +317,114 @@ func (h *eventHandlerImpl) updateNginxConf(ctx context.Context, conf dataplane.C
return fmt.Errorf("failed to reload NGINX: %w", err)
}

// If using NGINX Plus, update upstream servers using the API.
if err := h.updateUpstreamServers(conf); err != nil {
return fmt.Errorf("failed to update upstream servers: %w", err)
}

return nil
}

// updateUpstreamServers is called only when endpoints have changed. It updates nginx conf files and then:
// - if using NGINX Plus, determines which servers have changed and uses the N+ API to update them;
// - otherwise if not using NGINX Plus, or an error was returned from the API, reloads nginx.
func (h *eventHandlerImpl) updateUpstreamServers(
ctx context.Context,
logger logr.Logger,
conf dataplane.Configuration,
) error {
isPlus := h.cfg.nginxRuntimeMgr.IsPlus()

files := h.cfg.generator.Generate(conf)
if err := h.cfg.nginxFileMgr.ReplaceFiles(files); err != nil {
return fmt.Errorf("failed to replace NGINX configuration files: %w", err)
// updateUpstreamServers determines which servers have changed and uses the NGINX Plus API to update them.
// Only applicable when using NGINX Plus.
func (h *eventHandlerImpl) updateUpstreamServers(conf dataplane.Configuration) error {
if !h.cfg.plus {
return nil
}

reload := func() error {
if err := h.cfg.nginxRuntimeMgr.Reload(ctx, conf.Version); err != nil {
return fmt.Errorf("failed to reload NGINX: %w", err)
}
prevUpstreams, prevStreamUpstreams, err := h.cfg.nginxRuntimeMgr.GetUpstreams()
if err != nil {
return fmt.Errorf("failed to get upstreams from API: %w", err)
}

return nil
type upstream struct {
name string
servers []ngxclient.UpstreamServer
}
var upstreams []upstream

if isPlus {
type upstream struct {
name string
servers []ngxclient.UpstreamServer
for _, u := range conf.Upstreams {
confUpstream := upstream{
name: u.Name,
servers: ngxConfig.ConvertEndpoints(u.Endpoints),
}
var upstreams []upstream

prevUpstreams, err := h.cfg.nginxRuntimeMgr.GetUpstreams()
if err != nil {
logger.Error(err, "failed to get upstreams from API, reloading configuration instead")
return reload()
if u, ok := prevUpstreams[confUpstream.name]; ok {
if !serversEqual(confUpstream.servers, u.Peers) {
upstreams = append(upstreams, confUpstream)
}
}
}

for _, u := range conf.Upstreams {
confUpstream := upstream{
name: u.Name,
servers: ngxConfig.ConvertEndpoints(u.Endpoints),
}
type streamUpstream struct {
name string
servers []ngxclient.StreamUpstreamServer
}
var streamUpstreams []streamUpstream

if u, ok := prevUpstreams[confUpstream.name]; ok {
if !serversEqual(confUpstream.servers, u.Peers) {
upstreams = append(upstreams, confUpstream)
}
}
for _, u := range conf.StreamUpstreams {
confUpstream := streamUpstream{
name: u.Name,
servers: ngxConfig.ConvertStreamEndpoints(u.Endpoints),
}

var reloadPlus bool
for _, upstream := range upstreams {
if err := h.cfg.nginxRuntimeMgr.UpdateHTTPServers(upstream.name, upstream.servers); err != nil {
logger.Error(
err, "couldn't update upstream via the API, reloading configuration instead",
"upstreamName", upstream.name,
)
reloadPlus = true
if u, ok := prevStreamUpstreams[confUpstream.name]; ok {
if !serversEqual(confUpstream.servers, u.Peers) {
streamUpstreams = append(streamUpstreams, confUpstream)
}
}
}

if !reloadPlus {
return nil
var updateErr error
for _, upstream := range upstreams {
if err := h.cfg.nginxRuntimeMgr.UpdateHTTPServers(upstream.name, upstream.servers); err != nil {
updateErr = errors.Join(updateErr, fmt.Errorf(
"couldn't update upstream %q via the API: %w", upstream.name, err))
}
}

return reload()
for _, upstream := range streamUpstreams {
if err := h.cfg.nginxRuntimeMgr.UpdateStreamServers(upstream.name, upstream.servers); err != nil {
updateErr = errors.Join(updateErr, fmt.Errorf(
"couldn't update stream upstream %q via the API: %w", upstream.name, err))
}
}

return updateErr
}

func serversEqual(newServers []ngxclient.UpstreamServer, oldServers []ngxclient.Peer) bool {
// serversEqual accepts lists of either UpstreamServer/Peer or StreamUpstreamServer/StreamPeer and determines
// if the server names within these lists are equal.
func serversEqual[
upstreamServer ngxclient.UpstreamServer | ngxclient.StreamUpstreamServer,
peer ngxclient.Peer | ngxclient.StreamPeer,
](newServers []upstreamServer, oldServers []peer) bool {
if len(newServers) != len(oldServers) {
return false
}

getServerVal := func(T any) string {
var server string
switch t := T.(type) {
case ngxclient.UpstreamServer:
server = t.Server
case ngxclient.StreamUpstreamServer:
server = t.Server
case ngxclient.Peer:
server = t.Server
case ngxclient.StreamPeer:
server = t.Server
}
return server
}

diff := make(map[string]struct{}, len(newServers))
for _, s := range newServers {
diff[s.Server] = struct{}{}
diff[getServerVal(s)] = struct{}{}
}

for _, s := range oldServers {
if _, ok := diff[s.Server]; !ok {
if _, ok := diff[getServerVal(s)]; !ok {
return false
}
}
Expand Down
122 changes: 81 additions & 41 deletions internal/mode/static/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -423,20 +423,29 @@ var _ = Describe("eventHandler", func() {
},
},
}
fakeNginxRuntimeMgr.GetUpstreamsReturns(upstreams, nil)

streamUpstreams := ngxclient.StreamUpstreams{
"two": ngxclient.StreamUpstream{
Peers: []ngxclient.StreamPeer{
{Server: "server2"},
},
},
}

fakeNginxRuntimeMgr.GetUpstreamsReturns(upstreams, streamUpstreams, nil)
})

When("running NGINX Plus", func() {
It("should call the NGINX Plus API", func() {
fakeNginxRuntimeMgr.IsPlusReturns(true)
handler.cfg.plus = true

handler.HandleEventBatch(context.Background(), ctlrZap.New(), batch)

dcfg := dataplane.GetDefaultConfiguration(&graph.Graph{}, 1)
Expect(helpers.Diff(handler.GetLatestConfiguration(), &dcfg)).To(BeEmpty())

Expect(fakeGenerator.GenerateCallCount()).To(Equal(1))
Expect(fakeNginxFileMgr.ReplaceFilesCallCount()).To(Equal(1))
Expect(fakeGenerator.GenerateCallCount()).To(Equal(0))
Expect(fakeNginxFileMgr.ReplaceFilesCallCount()).To(Equal(0))
Expect(fakeNginxRuntimeMgr.GetUpstreamsCallCount()).To(Equal(1))
})
})
Expand All @@ -463,19 +472,11 @@ var _ = Describe("eventHandler", func() {
Name: "one",
},
},
}

type callCounts struct {
generate int
update int
reload int
}

assertCallCounts := func(cc callCounts) {
Expect(fakeGenerator.GenerateCallCount()).To(Equal(cc.generate))
Expect(fakeNginxFileMgr.ReplaceFilesCallCount()).To(Equal(cc.generate))
Expect(fakeNginxRuntimeMgr.UpdateHTTPServersCallCount()).To(Equal(cc.update))
Expect(fakeNginxRuntimeMgr.ReloadCallCount()).To(Equal(cc.reload))
StreamUpstreams: []dataplane.Upstream{
{
Name: "two",
},
},
}

BeforeEach(func() {
Expand All @@ -486,47 +487,49 @@ var _ = Describe("eventHandler", func() {
},
},
}
fakeNginxRuntimeMgr.GetUpstreamsReturns(upstreams, nil)

streamUpstreams := ngxclient.StreamUpstreams{
"two": ngxclient.StreamUpstream{
Peers: []ngxclient.StreamPeer{
{Server: "server2"},
},
},
}

fakeNginxRuntimeMgr.GetUpstreamsReturns(upstreams, streamUpstreams, nil)
})

When("running NGINX Plus", func() {
BeforeEach(func() {
fakeNginxRuntimeMgr.IsPlusReturns(true)
handler.cfg.plus = true
})

It("should update servers using the NGINX Plus API", func() {
Expect(handler.updateUpstreamServers(context.Background(), ctlrZap.New(), conf)).To(Succeed())

assertCallCounts(callCounts{generate: 1, update: 1, reload: 0})
Expect(handler.updateUpstreamServers(conf)).To(Succeed())
Expect(fakeNginxRuntimeMgr.UpdateHTTPServersCallCount()).To(Equal(1))
})

It("should reload when GET API returns an error", func() {
fakeNginxRuntimeMgr.GetUpstreamsReturns(nil, errors.New("error"))
Expect(handler.updateUpstreamServers(context.Background(), ctlrZap.New(), conf)).To(Succeed())

assertCallCounts(callCounts{generate: 1, update: 0, reload: 1})
It("should return error when GET API returns an error", func() {
fakeNginxRuntimeMgr.GetUpstreamsReturns(nil, nil, errors.New("error"))
Expect(handler.updateUpstreamServers(conf)).ToNot(Succeed())
})

It("should reload when POST API returns an error", func() {
It("should return error when UpdateHTTPServers API returns an error", func() {
fakeNginxRuntimeMgr.UpdateHTTPServersReturns(errors.New("error"))
Expect(handler.updateUpstreamServers(context.Background(), ctlrZap.New(), conf)).To(Succeed())
Expect(handler.updateUpstreamServers(conf)).ToNot(Succeed())
})

assertCallCounts(callCounts{generate: 1, update: 1, reload: 1})
It("should return error when UpdateStreamServers API returns an error", func() {
fakeNginxRuntimeMgr.UpdateStreamServersReturns(errors.New("error"))
Expect(handler.updateUpstreamServers(conf)).ToNot(Succeed())
})
})

When("not running NGINX Plus", func() {
It("should update servers by reloading", func() {
Expect(handler.updateUpstreamServers(context.Background(), ctlrZap.New(), conf)).To(Succeed())

assertCallCounts(callCounts{generate: 1, update: 0, reload: 1})
})
It("should not do anything", func() {
Expect(handler.updateUpstreamServers(conf)).To(Succeed())

It("should return an error when reloading fails", func() {
fakeNginxRuntimeMgr.ReloadReturns(errors.New("error"))
Expect(handler.updateUpstreamServers(context.Background(), ctlrZap.New(), conf)).ToNot(Succeed())

assertCallCounts(callCounts{generate: 1, update: 0, reload: 1})
Expect(fakeNginxRuntimeMgr.UpdateHTTPServersCallCount()).To(Equal(0))
})
})
})
Expand Down Expand Up @@ -612,7 +615,7 @@ var _ = Describe("eventHandler", func() {
})

var _ = Describe("serversEqual", func() {
DescribeTable("determines if server lists are equal",
DescribeTable("determines if HTTP server lists are equal",
func(newServers []ngxclient.UpstreamServer, oldServers []ngxclient.Peer, equal bool) {
Expect(serversEqual(newServers, oldServers)).To(Equal(equal))
},
Expand Down Expand Up @@ -649,6 +652,43 @@ var _ = Describe("serversEqual", func() {
true,
),
)
DescribeTable("determines if stream server lists are equal",
func(newServers []ngxclient.StreamUpstreamServer, oldServers []ngxclient.StreamPeer, equal bool) {
Expect(serversEqual(newServers, oldServers)).To(Equal(equal))
},
Entry("different length",
[]ngxclient.StreamUpstreamServer{
{Server: "server1"},
},
[]ngxclient.StreamPeer{
{Server: "server1"},
{Server: "server2"},
},
false,
),
Entry("differing elements",
[]ngxclient.StreamUpstreamServer{
{Server: "server1"},
{Server: "server2"},
},
[]ngxclient.StreamPeer{
{Server: "server1"},
{Server: "server3"},
},
false,
),
Entry("same elements",
[]ngxclient.StreamUpstreamServer{
{Server: "server1"},
{Server: "server2"},
},
[]ngxclient.StreamPeer{
{Server: "server1"},
{Server: "server2"},
},
true,
),
)
})

var _ = Describe("getGatewayAddresses", func() {
Expand Down
Loading

0 comments on commit 8e2e2d8

Please sign in to comment.