From ca0e1e68754fdf9b131419e0629ed2d99bdae713 Mon Sep 17 00:00:00 2001 From: xu0o0 Date: Fri, 15 Dec 2023 01:23:59 +0800 Subject: [PATCH] feat: allow to get the network address server is listening on (#217) Add an Addr() method to fetch the network address Server is listening on. I think this may be useful for the supervisor to start a server with 0 port to allocate an ephemeral port and get the allocated port for rendering an agent config. --- server/server.go | 5 +++++ server/serverimpl.go | 8 ++++++++ server/serverimpl_test.go | 42 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+) diff --git a/server/server.go b/server/server.go index d90db7b5..929f8675 100644 --- a/server/server.go +++ b/server/server.go @@ -62,4 +62,9 @@ type OpAMPServer interface { // Stop accepting new connections and close all current connections. This should // block until all connections are closed. Stop(ctx context.Context) error + + // Addr returns the network address Server is listening on. Nil if not started. + // Typically used to fetch the port when ListenEndpoint's port is specified as 0 to + // allocate an ephemeral port. + Addr() net.Addr } diff --git a/server/serverimpl.go b/server/serverimpl.go index 1c6a3f37..27a95874 100644 --- a/server/serverimpl.go +++ b/server/serverimpl.go @@ -40,6 +40,9 @@ type server struct { // The listening HTTP Server after successful Start() call. Nil if Start() // is not called or was not successful. httpServer *http.Server + + // The network address Server is listening on. Nil if not started. + addr net.Addr } var _ OpAMPServer = (*server)(nil) @@ -118,6 +121,7 @@ func (s *server) startHttpServer(listenAddr string, serveFunc func(l net.Listene if err != nil { return err } + s.addr = ln.Addr() // Begin serving connections in the background. go func() { @@ -143,6 +147,10 @@ func (s *server) Stop(ctx context.Context) error { return nil } +func (s *server) Addr() net.Addr { + return s.addr +} + func (s *server) httpHandler(w http.ResponseWriter, req *http.Request) { var connectionCallbacks serverTypes.ConnectionCallbacks if s.settings.Callbacks != nil { diff --git a/server/serverimpl_test.go b/server/serverimpl_test.go index 04827d38..de08ccb1 100644 --- a/server/serverimpl_test.go +++ b/server/serverimpl_test.go @@ -57,6 +57,48 @@ func TestServerStartStop(t *testing.T) { assert.NoError(t, err) } +func TestServerAddrWithNonZeroPort(t *testing.T) { + srv := New(&sharedinternal.NopLogger{}) + require.NotNil(t, srv) + + // Nil if not started + assert.Nil(t, srv.Addr()) + + addr := testhelpers.GetAvailableLocalAddress() + + err := srv.Start(StartSettings{ + ListenEndpoint: addr, + ListenPath: "/", + }) + assert.NoError(t, err) + + assert.Equal(t, addr, srv.Addr().String()) + + err = srv.Stop(context.Background()) + assert.NoError(t, err) +} + +func TestServerAddrWithZeroPort(t *testing.T) { + srv := New(&sharedinternal.NopLogger{}) + require.NotNil(t, srv) + + // Nil if not started + assert.Nil(t, srv.Addr()) + + err := srv.Start(StartSettings{ + ListenEndpoint: "127.0.0.1:0", + ListenPath: "/", + }) + assert.NoError(t, err) + + // should be listening on an non-zero ephemeral port + assert.NotEqual(t, "127.0.0.1:0", srv.Addr().String()) + assert.Regexp(t, `^127.0.0.1:\d+`, srv.Addr().String()) + + err = srv.Stop(context.Background()) + assert.NoError(t, err) +} + func TestServerStartRejectConnection(t *testing.T) { callbacks := CallbacksStruct{ OnConnectingFunc: func(request *http.Request) types.ConnectionResponse {