Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cmd/opampsupervisor] update types #37108

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 40 additions & 40 deletions cmd/opampsupervisor/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func (tl testLogger) Errorf(_ context.Context, format string, args ...any) {
tl.t.Logf(format, args...)
}

func defaultConnectingHandler(connectionCallbacks server.ConnectionCallbacksStruct) func(request *http.Request) types.ConnectionResponse {
func defaultConnectingHandler(connectionCallbacks types.Callbacks) func(request *http.Request) types.ConnectionResponse {
return func(_ *http.Request) types.ConnectionResponse {
return types.ConnectionResponse{
Accept: true,
Expand All @@ -73,11 +73,11 @@ func defaultConnectingHandler(connectionCallbacks server.ConnectionCallbacksStru
}
}

// onConnectingFuncFactory is a function that will be given to server.CallbacksStruct as
// onConnectingFuncFactory is a function that will be given to types.Callbacks as
// OnConnectingFunc. This allows changing the ConnectionCallbacks both from the newOpAMPServer
// caller and inside of newOpAMP Server, and for custom implementations of the value for `Accept`
// in types.ConnectionResponse.
type onConnectingFuncFactory func(connectionCallbacks server.ConnectionCallbacksStruct) func(request *http.Request) types.ConnectionResponse
type onConnectingFuncFactory func(connectionCallbacks types.Callbacks) func(request *http.Request) types.ConnectionResponse

type testingOpAMPServer struct {
addr string
Expand All @@ -87,13 +87,13 @@ type testingOpAMPServer struct {
shutdown func()
}

func newOpAMPServer(t *testing.T, connectingCallback onConnectingFuncFactory, callbacks server.ConnectionCallbacksStruct) *testingOpAMPServer {
func newOpAMPServer(t *testing.T, connectingCallback onConnectingFuncFactory, callbacks types.Callbacks) *testingOpAMPServer {
s := newUnstartedOpAMPServer(t, connectingCallback, callbacks)
s.start()
return s
}

func newUnstartedOpAMPServer(t *testing.T, connectingCallback onConnectingFuncFactory, callbacks server.ConnectionCallbacksStruct) *testingOpAMPServer {
func newUnstartedOpAMPServer(t *testing.T, connectingCallback onConnectingFuncFactory, callbacks types.Callbacks) *testingOpAMPServer {
var agentConn atomic.Value
var isAgentConnected atomic.Bool
var didShutdown atomic.Bool
Expand All @@ -117,8 +117,8 @@ func newUnstartedOpAMPServer(t *testing.T, connectingCallback onConnectingFuncFa
}
}
handler, _, err := s.Attach(server.Settings{
Callbacks: server.CallbacksStruct{
OnConnectingFunc: connectingCallback(callbacks),
Callbacks: types.Callbacks{
OnConnecting: connectingCallback(callbacks),
},
})
require.NoError(t, err)
Expand Down Expand Up @@ -211,8 +211,8 @@ func TestSupervisorStartsCollectorWithRemoteConfig(t *testing.T) {
server := newOpAMPServer(
t,
defaultConnectingHandler,
server.ConnectionCallbacksStruct{
OnMessageFunc: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
types.Callbacks{
OnMessage: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
if message.EffectiveConfig != nil {
config := message.EffectiveConfig.ConfigMap.ConfigMap[""]
if config != nil {
Expand Down Expand Up @@ -287,8 +287,8 @@ func TestSupervisorStartsCollectorWithNoOpAMPServer(t *testing.T) {
require.NoError(t, os.WriteFile(remoteConfigFilePath, marshalledRemoteConfig, 0o600))

connected := atomic.Bool{}
server := newUnstartedOpAMPServer(t, defaultConnectingHandler, server.ConnectionCallbacksStruct{
OnConnectedFunc: func(ctx context.Context, conn types.Connection) {
server := newUnstartedOpAMPServer(t, defaultConnectingHandler, types.Callbacks{
OnConnected: func(ctx context.Context, conn types.Connection) {
connected.Store(true)
},
})
Expand Down Expand Up @@ -331,11 +331,11 @@ func TestSupervisorStartsWithNoOpAMPServer(t *testing.T) {

configuredChan := make(chan struct{})
connected := atomic.Bool{}
server := newUnstartedOpAMPServer(t, defaultConnectingHandler, server.ConnectionCallbacksStruct{
OnConnectedFunc: func(ctx context.Context, conn types.Connection) {
server := newUnstartedOpAMPServer(t, defaultConnectingHandler, types.Callbacks{
OnConnected: func(ctx context.Context, conn types.Connection) {
connected.Store(true)
},
OnMessageFunc: func(ctx context.Context, conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
OnMessage: func(ctx context.Context, conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
lastCfgHash := message.GetRemoteConfigStatus().GetLastRemoteConfigHash()
if bytes.Equal(lastCfgHash, hash) {
close(configuredChan)
Expand Down Expand Up @@ -415,8 +415,8 @@ func TestSupervisorRestartsCollectorAfterBadConfig(t *testing.T) {
server := newOpAMPServer(
t,
defaultConnectingHandler,
server.ConnectionCallbacksStruct{
OnMessageFunc: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
types.Callbacks{
OnMessage: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
if message.Health != nil {
healthReport.Store(message.Health)
}
Expand Down Expand Up @@ -501,8 +501,8 @@ func TestSupervisorConfiguresCapabilities(t *testing.T) {
server := newOpAMPServer(
t,
defaultConnectingHandler,
server.ConnectionCallbacksStruct{
OnMessageFunc: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
types.Callbacks{
OnMessage: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
capabilities.Store(message.Capabilities)

return &protobufs.ServerToAgent{}
Expand Down Expand Up @@ -556,8 +556,8 @@ func TestSupervisorBootstrapsCollector(t *testing.T) {
server := newOpAMPServer(
t,
defaultConnectingHandler,
server.ConnectionCallbacksStruct{
OnMessageFunc: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
types.Callbacks{
OnMessage: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
if message.AgentDescription != nil {
agentDescription.Store(message.AgentDescription)
}
Expand Down Expand Up @@ -602,8 +602,8 @@ func TestSupervisorReportsEffectiveConfig(t *testing.T) {
server := newOpAMPServer(
t,
defaultConnectingHandler,
server.ConnectionCallbacksStruct{
OnMessageFunc: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
types.Callbacks{
OnMessage: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
if message.EffectiveConfig != nil {
config := message.EffectiveConfig.ConfigMap.ConfigMap[""]
if config != nil {
Expand Down Expand Up @@ -713,8 +713,8 @@ func TestSupervisorAgentDescriptionConfigApplies(t *testing.T) {
server := newOpAMPServer(
t,
defaultConnectingHandler,
server.ConnectionCallbacksStruct{
OnMessageFunc: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
types.Callbacks{
OnMessage: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
if message.AgentDescription != nil {
select {
case agentDescMessageChan <- message:
Expand Down Expand Up @@ -866,8 +866,8 @@ func TestSupervisorRestartCommand(t *testing.T) {
server := newOpAMPServer(
t,
defaultConnectingHandler,
server.ConnectionCallbacksStruct{
OnMessageFunc: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
types.Callbacks{
OnMessage: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
if message.Health != nil {
healthReport.Store(message.Health)
}
Expand Down Expand Up @@ -999,8 +999,8 @@ func TestSupervisorRestartsWithLastReceivedConfig(t *testing.T) {
initialServer := newOpAMPServer(
t,
defaultConnectingHandler,
server.ConnectionCallbacksStruct{
OnMessageFunc: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
types.Callbacks{
OnMessage: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
if message.EffectiveConfig != nil {
config := message.EffectiveConfig.ConfigMap.ConfigMap[""]
if config != nil {
Expand Down Expand Up @@ -1043,8 +1043,8 @@ func TestSupervisorRestartsWithLastReceivedConfig(t *testing.T) {
newServer := newOpAMPServer(
t,
defaultConnectingHandler,
server.ConnectionCallbacksStruct{
OnMessageFunc: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
types.Callbacks{
OnMessage: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
if message.EffectiveConfig != nil {
config := message.EffectiveConfig.ConfigMap.ConfigMap[""]
if config != nil {
Expand Down Expand Up @@ -1087,8 +1087,8 @@ func TestSupervisorPersistsInstanceID(t *testing.T) {
server := newOpAMPServer(
t,
defaultConnectingHandler,
server.ConnectionCallbacksStruct{
OnMessageFunc: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
types.Callbacks{
OnMessage: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
select {
case agentIDChan <- message.InstanceUid:
default:
Expand Down Expand Up @@ -1163,8 +1163,8 @@ func TestSupervisorPersistsNewInstanceID(t *testing.T) {
server := newOpAMPServer(
t,
defaultConnectingHandler,
server.ConnectionCallbacksStruct{
OnMessageFunc: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
types.Callbacks{
OnMessage: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
select {
case agentIDChan <- message.InstanceUid:
default:
Expand Down Expand Up @@ -1270,8 +1270,8 @@ func TestSupervisorStopsAgentProcessWithEmptyConfigMap(t *testing.T) {
server := newOpAMPServer(
t,
defaultConnectingHandler,
server.ConnectionCallbacksStruct{
OnMessageFunc: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
types.Callbacks{
OnMessage: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
if message.EffectiveConfig != nil {
config := message.EffectiveConfig.ConfigMap.ConfigMap[""]
if config != nil {
Expand Down Expand Up @@ -1449,8 +1449,8 @@ func TestSupervisorRemoteConfigApplyStatus(t *testing.T) {
server := newOpAMPServer(
t,
defaultConnectingHandler,
server.ConnectionCallbacksStruct{
OnMessageFunc: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
types.Callbacks{
OnMessage: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
if message.EffectiveConfig != nil {
config := message.EffectiveConfig.ConfigMap.ConfigMap[""]
if config != nil {
Expand Down Expand Up @@ -1586,8 +1586,8 @@ func TestSupervisorOpAmpServerPort(t *testing.T) {
server := newOpAMPServer(
t,
defaultConnectingHandler,
server.ConnectionCallbacksStruct{
OnMessageFunc: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
types.Callbacks{
OnMessage: func(_ context.Context, _ types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
if message.EffectiveConfig != nil {
config := message.EffectiveConfig.ConfigMap.ConfigMap[""]
if config != nil {
Expand Down
30 changes: 17 additions & 13 deletions cmd/opampsupervisor/supervisor/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,26 @@ import (
)

type flattenedSettings struct {
onMessageFunc func(conn serverTypes.Connection, message *protobufs.AgentToServer)
onConnectingFunc func(request *http.Request) (shouldConnect bool, rejectStatusCode int)
onConnectionCloseFunc func(conn serverTypes.Connection)
endpoint string
onMessage func(conn serverTypes.Connection, message *protobufs.AgentToServer)
onConnecting func(request *http.Request) (shouldConnect bool, rejectStatusCode int)
onConnectionClose func(conn serverTypes.Connection)
endpoint string
}

func (fs flattenedSettings) toServerSettings() server.StartSettings {
return server.StartSettings{
Settings: server.Settings{
Callbacks: fs,
Callbacks: serverTypes.Callbacks{
OnConnecting: fs.OnConnecting,
},
},
ListenEndpoint: fs.endpoint,
}
}

func (fs flattenedSettings) OnConnecting(request *http.Request) serverTypes.ConnectionResponse {
if fs.onConnectingFunc != nil {
shouldConnect, rejectStatusCode := fs.onConnectingFunc(request)
if fs.onConnecting != nil {
shouldConnect, rejectStatusCode := fs.onConnecting(request)
if !shouldConnect {
return serverTypes.ConnectionResponse{
Accept: false,
Expand All @@ -40,23 +42,25 @@ func (fs flattenedSettings) OnConnecting(request *http.Request) serverTypes.Conn
}

return serverTypes.ConnectionResponse{
Accept: true,
ConnectionCallbacks: fs,
Accept: true,
ConnectionCallbacks: serverTypes.ConnectionCallbacks{
OnMessage: fs.OnMessage,
},
}
}

func (fs flattenedSettings) OnConnected(_ context.Context, _ serverTypes.Connection) {}

func (fs flattenedSettings) OnMessage(_ context.Context, conn serverTypes.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
if fs.onMessageFunc != nil {
fs.onMessageFunc(conn, message)
if fs.onMessage != nil {
fs.onMessage(conn, message)
}

return &protobufs.ServerToAgent{}
}

func (fs flattenedSettings) OnConnectionClose(conn serverTypes.Connection) {
if fs.onConnectionCloseFunc != nil {
fs.onConnectionCloseFunc(conn)
if fs.onConnectionClose != nil {
fs.onConnectionClose(conn)
}
}
8 changes: 4 additions & 4 deletions cmd/opampsupervisor/supervisor/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func Test_flattenedSettings_OnConnecting(t *testing.T) {
t.Run("accept connection", func(t *testing.T) {
onConnectingFuncCalled := false
fs := flattenedSettings{
onConnectingFunc: func(_ *http.Request) (shouldConnect bool, rejectStatusCode int) {
onConnecting: func(_ *http.Request) (shouldConnect bool, rejectStatusCode int) {
onConnectingFuncCalled = true
return true, 0
},
Expand All @@ -43,7 +43,7 @@ func Test_flattenedSettings_OnConnecting(t *testing.T) {
t.Run("do not accept connection", func(t *testing.T) {
onConnectingFuncCalled := false
fs := flattenedSettings{
onConnectingFunc: func(_ *http.Request) (shouldConnect bool, rejectStatusCode int) {
onConnecting: func(_ *http.Request) (shouldConnect bool, rejectStatusCode int) {
onConnectingFuncCalled = true
return false, 500
},
Expand All @@ -60,7 +60,7 @@ func Test_flattenedSettings_OnConnecting(t *testing.T) {
func Test_flattenedSettings_OnMessage(t *testing.T) {
onMessageFuncCalled := false
fs := flattenedSettings{
onMessageFunc: func(_ serverTypes.Connection, _ *protobufs.AgentToServer) {
onMessage: func(_ serverTypes.Connection, _ *protobufs.AgentToServer) {
onMessageFuncCalled = true
},
}
Expand All @@ -74,7 +74,7 @@ func Test_flattenedSettings_OnMessage(t *testing.T) {
func Test_flattenedSettings_OnConnectionClose(t *testing.T) {
onConnectionCloseFuncCalled := false
fs := flattenedSettings{
onConnectionCloseFunc: func(_ serverTypes.Connection) {
onConnectionClose: func(_ serverTypes.Connection) {
onConnectionCloseFuncCalled = true
},
}
Expand Down
Loading
Loading