Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cmd/opampsupervisor] update types #37108

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 61 additions & 60 deletions cmd/opampsupervisor/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func (tl testLogger) Errorf(_ context.Context, format string, args ...any) {
tl.t.Logf(format, args...)
}

func defaultConnectingHandler(connectionCallbacks server.ConnectionCallbacksStruct) func(request *http.Request) types.ConnectionResponse {
func defaultConnectingHandler(connectionCallbacks types.ConnectionCallbacks) func(request *http.Request) types.ConnectionResponse {
return func(_ *http.Request) types.ConnectionResponse {
return types.ConnectionResponse{
Accept: true,
Expand All @@ -73,11 +73,11 @@ func defaultConnectingHandler(connectionCallbacks server.ConnectionCallbacksStru
}
}

// onConnectingFuncFactory is a function that will be given to server.CallbacksStruct as
// onConnectingFuncFactory is a function that will be given to types.ConnectionCallbacks as
// OnConnectingFunc. This allows changing the ConnectionCallbacks both from the newOpAMPServer
// caller and inside of newOpAMP Server, and for custom implementations of the value for `Accept`
// in types.ConnectionResponse.
type onConnectingFuncFactory func(connectionCallbacks server.ConnectionCallbacksStruct) func(request *http.Request) types.ConnectionResponse
type onConnectingFuncFactory func(connectionCallbacks types.ConnectionCallbacks) func(request *http.Request) types.ConnectionResponse

type testingOpAMPServer struct {
addr string
Expand All @@ -87,38 +87,38 @@ type testingOpAMPServer struct {
shutdown func()
}

func newOpAMPServer(t *testing.T, connectingCallback onConnectingFuncFactory, callbacks server.ConnectionCallbacksStruct) *testingOpAMPServer {
func newOpAMPServer(t *testing.T, connectingCallback onConnectingFuncFactory, callbacks types.ConnectionCallbacks) *testingOpAMPServer {
s := newUnstartedOpAMPServer(t, connectingCallback, callbacks)
s.start()
return s
}

func newUnstartedOpAMPServer(t *testing.T, connectingCallback onConnectingFuncFactory, callbacks server.ConnectionCallbacksStruct) *testingOpAMPServer {
func newUnstartedOpAMPServer(t *testing.T, connectingCallback onConnectingFuncFactory, callbacks types.ConnectionCallbacks) *testingOpAMPServer {
var agentConn atomic.Value
var isAgentConnected atomic.Bool
var didShutdown atomic.Bool
connectedChan := make(chan bool)
s := server.New(testLogger{t: t})
onConnectedFunc := callbacks.OnConnectedFunc
callbacks.OnConnectedFunc = func(ctx context.Context, conn types.Connection) {
onConnectedFunc := callbacks.OnConnected
callbacks.OnConnected = func(ctx context.Context, conn types.Connection) {
if onConnectedFunc != nil {
onConnectedFunc(ctx, conn)
}
agentConn.Store(conn)
isAgentConnected.Store(true)
connectedChan <- true
}
onConnectionCloseFunc := callbacks.OnConnectionCloseFunc
callbacks.OnConnectionCloseFunc = func(conn types.Connection) {
onConnectionCloseFunc := callbacks.OnConnectionClose
callbacks.OnConnectionClose = func(conn types.Connection) {
isAgentConnected.Store(false)
connectedChan <- false
if onConnectionCloseFunc != nil {
onConnectionCloseFunc(conn)
}
}
handler, _, err := s.Attach(server.Settings{
Callbacks: server.CallbacksStruct{
OnConnectingFunc: connectingCallback(callbacks),
Callbacks: types.Callbacks{
OnConnecting: connectingCallback(callbacks),
},
})
require.NoError(t, err)
Expand Down Expand Up @@ -211,8 +211,8 @@ func TestSupervisorStartsCollectorWithRemoteConfig(t *testing.T) {
server := newOpAMPServer(
t,
defaultConnectingHandler,
server.ConnectionCallbacksStruct{
OnMessageFunc: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
types.ConnectionCallbacks{
OnMessage: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
if message.EffectiveConfig != nil {
config := message.EffectiveConfig.ConfigMap.ConfigMap[""]
if config != nil {
Expand Down Expand Up @@ -287,8 +287,8 @@ func TestSupervisorStartsCollectorWithNoOpAMPServer(t *testing.T) {
require.NoError(t, os.WriteFile(remoteConfigFilePath, marshalledRemoteConfig, 0o600))

connected := atomic.Bool{}
server := newUnstartedOpAMPServer(t, defaultConnectingHandler, server.ConnectionCallbacksStruct{
OnConnectedFunc: func(ctx context.Context, conn types.Connection) {
server := newUnstartedOpAMPServer(t, defaultConnectingHandler, types.ConnectionCallbacks{
OnConnected: func(ctx context.Context, conn types.Connection) {
connected.Store(true)
},
})
Expand Down Expand Up @@ -331,19 +331,20 @@ func TestSupervisorStartsWithNoOpAMPServer(t *testing.T) {

configuredChan := make(chan struct{})
connected := atomic.Bool{}
server := newUnstartedOpAMPServer(t, defaultConnectingHandler, server.ConnectionCallbacksStruct{
OnConnectedFunc: func(ctx context.Context, conn types.Connection) {
connected.Store(true)
},
OnMessageFunc: func(ctx context.Context, conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
lastCfgHash := message.GetRemoteConfigStatus().GetLastRemoteConfigHash()
if bytes.Equal(lastCfgHash, hash) {
close(configuredChan)
}
server := newUnstartedOpAMPServer(t, defaultConnectingHandler,
types.ConnectionCallbacks{
OnConnected: func(ctx context.Context, conn types.Connection) {
connected.Store(true)
},
OnMessage: func(ctx context.Context, conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
lastCfgHash := message.GetRemoteConfigStatus().GetLastRemoteConfigHash()
if bytes.Equal(lastCfgHash, hash) {
close(configuredChan)
}

return &protobufs.ServerToAgent{}
},
})
return &protobufs.ServerToAgent{}
},
})
defer server.shutdown()

// The supervisor is started without a running OpAMP server.
Expand Down Expand Up @@ -415,8 +416,8 @@ func TestSupervisorRestartsCollectorAfterBadConfig(t *testing.T) {
server := newOpAMPServer(
t,
defaultConnectingHandler,
server.ConnectionCallbacksStruct{
OnMessageFunc: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
types.ConnectionCallbacks{
OnMessage: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
if message.Health != nil {
healthReport.Store(message.Health)
}
Expand Down Expand Up @@ -501,8 +502,8 @@ func TestSupervisorConfiguresCapabilities(t *testing.T) {
server := newOpAMPServer(
t,
defaultConnectingHandler,
server.ConnectionCallbacksStruct{
OnMessageFunc: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
types.ConnectionCallbacks{
OnMessage: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
capabilities.Store(message.Capabilities)

return &protobufs.ServerToAgent{}
Expand Down Expand Up @@ -556,8 +557,8 @@ func TestSupervisorBootstrapsCollector(t *testing.T) {
server := newOpAMPServer(
t,
defaultConnectingHandler,
server.ConnectionCallbacksStruct{
OnMessageFunc: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
types.ConnectionCallbacks{
OnMessage: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
if message.AgentDescription != nil {
agentDescription.Store(message.AgentDescription)
}
Expand Down Expand Up @@ -602,8 +603,8 @@ func TestSupervisorReportsEffectiveConfig(t *testing.T) {
server := newOpAMPServer(
t,
defaultConnectingHandler,
server.ConnectionCallbacksStruct{
OnMessageFunc: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
types.ConnectionCallbacks{
OnMessage: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
if message.EffectiveConfig != nil {
config := message.EffectiveConfig.ConfigMap.ConfigMap[""]
if config != nil {
Expand Down Expand Up @@ -713,8 +714,8 @@ func TestSupervisorAgentDescriptionConfigApplies(t *testing.T) {
server := newOpAMPServer(
t,
defaultConnectingHandler,
server.ConnectionCallbacksStruct{
OnMessageFunc: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
types.ConnectionCallbacks{
OnMessage: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
if message.AgentDescription != nil {
select {
case agentDescMessageChan <- message:
Expand Down Expand Up @@ -866,8 +867,8 @@ func TestSupervisorRestartCommand(t *testing.T) {
server := newOpAMPServer(
t,
defaultConnectingHandler,
server.ConnectionCallbacksStruct{
OnMessageFunc: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
types.ConnectionCallbacks{
OnMessage: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
if message.Health != nil {
healthReport.Store(message.Health)
}
Expand Down Expand Up @@ -948,7 +949,7 @@ func TestSupervisorOpAMPConnectionSettings(t *testing.T) {
initialServer := newOpAMPServer(
t,
defaultConnectingHandler,
server.ConnectionCallbacksStruct{})
types.ConnectionCallbacks{})

s := newSupervisor(t, "accepts_conn", map[string]string{"url": initialServer.addr})

Expand All @@ -960,11 +961,11 @@ func TestSupervisorOpAMPConnectionSettings(t *testing.T) {
newServer := newOpAMPServer(
t,
defaultConnectingHandler,
server.ConnectionCallbacksStruct{
OnConnectedFunc: func(_ context.Context, _ types.Connection) {
types.ConnectionCallbacks{
OnConnected: func(_ context.Context, _ types.Connection) {
connectedToNewServer.Store(true)
},
OnMessageFunc: func(_ context.Context, _ types.Connection, _ *protobufs.AgentToServer) *protobufs.ServerToAgent {
OnMessage: func(_ context.Context, _ types.Connection, _ *protobufs.AgentToServer) *protobufs.ServerToAgent {
return &protobufs.ServerToAgent{}
},
})
Expand Down Expand Up @@ -999,8 +1000,8 @@ func TestSupervisorRestartsWithLastReceivedConfig(t *testing.T) {
initialServer := newOpAMPServer(
t,
defaultConnectingHandler,
server.ConnectionCallbacksStruct{
OnMessageFunc: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
types.ConnectionCallbacks{
OnMessage: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
if message.EffectiveConfig != nil {
config := message.EffectiveConfig.ConfigMap.ConfigMap[""]
if config != nil {
Expand Down Expand Up @@ -1043,8 +1044,8 @@ func TestSupervisorRestartsWithLastReceivedConfig(t *testing.T) {
newServer := newOpAMPServer(
t,
defaultConnectingHandler,
server.ConnectionCallbacksStruct{
OnMessageFunc: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
types.ConnectionCallbacks{
OnMessage: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
if message.EffectiveConfig != nil {
config := message.EffectiveConfig.ConfigMap.ConfigMap[""]
if config != nil {
Expand Down Expand Up @@ -1087,8 +1088,8 @@ func TestSupervisorPersistsInstanceID(t *testing.T) {
server := newOpAMPServer(
t,
defaultConnectingHandler,
server.ConnectionCallbacksStruct{
OnMessageFunc: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
types.ConnectionCallbacks{
OnMessage: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
select {
case agentIDChan <- message.InstanceUid:
default:
Expand Down Expand Up @@ -1163,8 +1164,8 @@ func TestSupervisorPersistsNewInstanceID(t *testing.T) {
server := newOpAMPServer(
t,
defaultConnectingHandler,
server.ConnectionCallbacksStruct{
OnMessageFunc: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
types.ConnectionCallbacks{
OnMessage: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
select {
case agentIDChan <- message.InstanceUid:
default:
Expand Down Expand Up @@ -1242,7 +1243,7 @@ func TestSupervisorWritesAgentFilesToStorageDir(t *testing.T) {
server := newOpAMPServer(
t,
defaultConnectingHandler,
server.ConnectionCallbacksStruct{},
types.ConnectionCallbacks{},
)

s := newSupervisor(t, "basic", map[string]string{
Expand Down Expand Up @@ -1270,8 +1271,8 @@ func TestSupervisorStopsAgentProcessWithEmptyConfigMap(t *testing.T) {
server := newOpAMPServer(
t,
defaultConnectingHandler,
server.ConnectionCallbacksStruct{
OnMessageFunc: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
types.ConnectionCallbacks{
OnMessage: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
if message.EffectiveConfig != nil {
config := message.EffectiveConfig.ConfigMap.ConfigMap[""]
if config != nil {
Expand Down Expand Up @@ -1386,8 +1387,8 @@ func TestSupervisorLogging(t *testing.T) {
require.NoError(t, os.WriteFile(remoteCfgFilePath, marshalledRemoteCfg, 0o600))

connected := atomic.Bool{}
server := newUnstartedOpAMPServer(t, defaultConnectingHandler, server.ConnectionCallbacksStruct{
OnConnectedFunc: func(ctx context.Context, conn types.Connection) {
server := newUnstartedOpAMPServer(t, defaultConnectingHandler, types.ConnectionCallbacks{
OnConnected: func(ctx context.Context, conn types.Connection) {
connected.Store(true)
},
})
Expand Down Expand Up @@ -1449,8 +1450,8 @@ func TestSupervisorRemoteConfigApplyStatus(t *testing.T) {
server := newOpAMPServer(
t,
defaultConnectingHandler,
server.ConnectionCallbacksStruct{
OnMessageFunc: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
types.ConnectionCallbacks{
OnMessage: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
if message.EffectiveConfig != nil {
config := message.EffectiveConfig.ConfigMap.ConfigMap[""]
if config != nil {
Expand Down Expand Up @@ -1586,8 +1587,8 @@ func TestSupervisorOpAmpServerPort(t *testing.T) {
server := newOpAMPServer(
t,
defaultConnectingHandler,
server.ConnectionCallbacksStruct{
OnMessageFunc: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
types.ConnectionCallbacks{
OnMessage: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
if message.EffectiveConfig != nil {
config := message.EffectiveConfig.ConfigMap.ConfigMap[""]
if config != nil {
Expand Down
30 changes: 17 additions & 13 deletions cmd/opampsupervisor/supervisor/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,26 @@ import (
)

type flattenedSettings struct {
onMessageFunc func(conn serverTypes.Connection, message *protobufs.AgentToServer)
onConnectingFunc func(request *http.Request) (shouldConnect bool, rejectStatusCode int)
onConnectionCloseFunc func(conn serverTypes.Connection)
endpoint string
onMessage func(conn serverTypes.Connection, message *protobufs.AgentToServer)
onConnecting func(request *http.Request) (shouldConnect bool, rejectStatusCode int)
onConnectionClose func(conn serverTypes.Connection)
endpoint string
}

func (fs flattenedSettings) toServerSettings() server.StartSettings {
return server.StartSettings{
Settings: server.Settings{
Callbacks: fs,
Callbacks: serverTypes.Callbacks{
OnConnecting: fs.OnConnecting,
},
},
ListenEndpoint: fs.endpoint,
}
}

func (fs flattenedSettings) OnConnecting(request *http.Request) serverTypes.ConnectionResponse {
if fs.onConnectingFunc != nil {
shouldConnect, rejectStatusCode := fs.onConnectingFunc(request)
if fs.onConnecting != nil {
shouldConnect, rejectStatusCode := fs.onConnecting(request)
if !shouldConnect {
return serverTypes.ConnectionResponse{
Accept: false,
Expand All @@ -40,23 +42,25 @@ func (fs flattenedSettings) OnConnecting(request *http.Request) serverTypes.Conn
}

return serverTypes.ConnectionResponse{
Accept: true,
ConnectionCallbacks: fs,
Accept: true,
ConnectionCallbacks: serverTypes.ConnectionCallbacks{
OnMessage: fs.OnMessage,
},
}
}

func (fs flattenedSettings) OnConnected(_ context.Context, _ serverTypes.Connection) {}

func (fs flattenedSettings) OnMessage(_ context.Context, conn serverTypes.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
if fs.onMessageFunc != nil {
fs.onMessageFunc(conn, message)
if fs.onMessage != nil {
fs.onMessage(conn, message)
}

return &protobufs.ServerToAgent{}
}

func (fs flattenedSettings) OnConnectionClose(conn serverTypes.Connection) {
if fs.onConnectionCloseFunc != nil {
fs.onConnectionCloseFunc(conn)
if fs.onConnectionClose != nil {
fs.onConnectionClose(conn)
}
}
Loading
Loading