diff --git a/coderd/coderd.go b/coderd/coderd.go
index e04f13d367c6e..b5b243c44135e 100644
--- a/coderd/coderd.go
+++ b/coderd/coderd.go
@@ -1399,9 +1399,6 @@ func (api *API) Close() error {
 	default:
 		api.cancel()
 	}
-	if api.derpCloseFunc != nil {
-		api.derpCloseFunc()
-	}
 
 	wsDone := make(chan struct{})
 	timer := time.NewTimer(10 * time.Second)
@@ -1427,11 +1424,14 @@ func (api *API) Close() error {
 		api.updateChecker.Close()
 	}
 	_ = api.workspaceAppServer.Close()
+	_ = api.agentProvider.Close()
+	if api.derpCloseFunc != nil {
+		api.derpCloseFunc()
+	}
 	coordinator := api.TailnetCoordinator.Load()
 	if coordinator != nil {
 		_ = (*coordinator).Close()
 	}
-	_ = api.agentProvider.Close()
 	_ = api.statsReporter.Close()
 	_ = api.NetworkTelemetryBatcher.Close()
 	return nil
diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go
index 077d704be1300..2a713787f499e 100644
--- a/coderd/database/dbauthz/dbauthz.go
+++ b/coderd/database/dbauthz/dbauthz.go
@@ -253,6 +253,7 @@ var (
 			rbac.ResourceDeploymentConfig.Type:       {policy.ActionCreate, policy.ActionUpdate, policy.ActionDelete},
 			rbac.ResourceNotificationPreference.Type: {policy.ActionCreate, policy.ActionUpdate, policy.ActionDelete},
 			rbac.ResourceNotificationTemplate.Type:   {policy.ActionCreate, policy.ActionUpdate, policy.ActionDelete},
+			rbac.ResourceTailnetCoordinator.Type:     {policy.ActionCreate, policy.ActionUpdate, policy.ActionDelete},
 		}),
 		Org:  map[string][]rbac.Permission{},
 		User: []rbac.Permission{},
diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go
index f8530ca990aed..5072726015157 100644
--- a/enterprise/tailnet/pgcoord.go
+++ b/enterprise/tailnet/pgcoord.go
@@ -506,7 +506,8 @@ func newBinder(ctx context.Context,
 
 	b.logger.Debug(b.ctx, "updating peers to lost")
 
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
+	//nolint:gocritic // the coordinator acts as the system here
+	ctx, cancel := context.WithTimeout(dbauthz.AsSystemRestricted(context.Background()), time.Second*15)
 	defer cancel()
 	err := b.store.UpdateTailnetPeerStatusByCoordinator(ctx, database.UpdateTailnetPeerStatusByCoordinatorParams{
 		CoordinatorID: b.coordinatorID,
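For context on the pgcoord.go hunk above: a dbauthz-wrapped store rejects any context that carries no authorized actor, so a background goroutine (which has no request actor) must explicitly assume the system identity before touching the database; the dbauthz.go hunk is what grants the system subject the matching permissions on ResourceTailnetCoordinator. A minimal sketch of the pattern, assuming a dbauthz-wrapped database.Store; markPeersLost is a hypothetical helper, not part of this PR:

	package example

	import (
		"context"
		"time"

		"github.com/google/uuid"

		"github.com/coder/coder/v2/coderd/database"
		"github.com/coder/coder/v2/coderd/database/dbauthz"
	)

	// markPeersLost mirrors what the binder does on close: attach the
	// system actor to a fresh context, bound the call with a timeout, and
	// mark every peer owned by this coordinator as lost.
	func markPeersLost(store database.Store, coordinatorID uuid.UUID) error {
		//nolint:gocritic // a background task has no user actor to act as
		ctx, cancel := context.WithTimeout(dbauthz.AsSystemRestricted(context.Background()), 15*time.Second)
		defer cancel()
		return store.UpdateTailnetPeerStatusByCoordinator(ctx, database.UpdateTailnetPeerStatusByCoordinatorParams{
			CoordinatorID: coordinatorID,
			// remaining fields elided; see the hunk above
		})
	}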
"github.com/coder/serpent" ) func TestMain(m *testing.M) { @@ -913,6 +929,287 @@ func TestPGCoordinatorDual_PeerReconnect(t *testing.T) { p2.AssertNeverUpdateKind(p1.ID, proto.CoordinateResponse_PeerUpdate_DISCONNECTED) } +// restartableListener is a TCP listener that can have all of it's connections +// severed on demand. +type restartableListener struct { + net.Listener + mu sync.Mutex + conns []net.Conn +} + +func (l *restartableListener) Accept() (net.Conn, error) { + conn, err := l.Listener.Accept() + if err != nil { + return nil, err + } + l.mu.Lock() + l.conns = append(l.conns, conn) + l.mu.Unlock() + return conn, nil +} + +func (l *restartableListener) CloseConnections() { + l.mu.Lock() + defer l.mu.Unlock() + for _, conn := range l.conns { + _ = conn.Close() + } + l.conns = nil +} + +type restartableTestServer struct { + options *coderdenttest.Options + rl *restartableListener + + mu sync.Mutex + api *coderd.API + closer io.Closer +} + +func newRestartableTestServer(t *testing.T, options *coderdenttest.Options) (*codersdk.Client, codersdk.CreateFirstUserResponse, *restartableTestServer) { + t.Helper() + if options == nil { + options = &coderdenttest.Options{} + } + + s := &restartableTestServer{ + options: options, + } + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.mu.Lock() + api := s.api + s.mu.Unlock() + + if api == nil { + w.WriteHeader(http.StatusBadGateway) + _, _ = w.Write([]byte("server is not started")) + return + } + api.AGPL.RootHandler.ServeHTTP(w, r) + })) + s.rl = &restartableListener{Listener: srv.Listener} + srv.Listener = s.rl + srv.Start() + t.Cleanup(srv.Close) + + u, err := url.Parse(srv.URL) + require.NoError(t, err, "failed to parse server URL") + s.options.AccessURL = u + + client, firstUser := s.startWithFirstUser(t) + client.URL = u + return client, firstUser, s +} + +func (s *restartableTestServer) Stop(t *testing.T) { + t.Helper() + + s.mu.Lock() + closer := s.closer + s.closer = nil + api := s.api + s.api = nil + s.mu.Unlock() + + if closer != nil { + err := closer.Close() + require.NoError(t, err) + } + if api != nil { + err := api.Close() + require.NoError(t, err) + } + + s.rl.CloseConnections() +} + +func (s *restartableTestServer) Start(t *testing.T) { + t.Helper() + _, _ = s.startWithFirstUser(t) +} + +func (s *restartableTestServer) startWithFirstUser(t *testing.T) (client *codersdk.Client, firstUser codersdk.CreateFirstUserResponse) { + t.Helper() + s.mu.Lock() + defer s.mu.Unlock() + + if s.closer != nil || s.api != nil { + t.Fatal("server already started, close must be called first") + } + // This creates it's own TCP listener unfortunately, but it's not being + // used in this test. + client, s.closer, s.api, firstUser = coderdenttest.NewWithAPI(t, s.options) + + // Never add the first user or license on subsequent restarts. + s.options.DontAddFirstUser = true + s.options.DontAddLicense = true + + return client, firstUser +} + +// Test_CoordinatorRollingRestart tests that two peers can maintain a connection +// without forgetting about each other when a HA coordinator does a rolling +// restart. +// +// We had a few issues with this in the past: +// 1. We didn't allow clients to maintain their peer ID after a reconnect, +// which resulted in the other peer thinking the client was a new peer. +// (This is fixed and independently tested in AGPL code) +// 2. 
+// TestConn_CoordinatorRollingRestart tests that two peers can maintain a
+// connection without forgetting about each other when an HA coordinator does
+// a rolling restart.
+//
+// We had a few issues with this in the past:
+//  1. We didn't allow clients to maintain their peer ID after a reconnect,
+//     which resulted in the other peer thinking the client was a new peer.
+//     (This is fixed and independently tested in AGPL code)
+//  2. HA coordinators would delete all peers (via FK constraints) when they
+//     were closed, which meant tunnels would be deleted and peers would be
+//     notified that the other peer was permanently gone.
+//     (This is fixed and independently tested above)
+//
+// This test uses a real server and real clients.
+func TestConn_CoordinatorRollingRestart(t *testing.T) {
+	t.Parallel()
+
+	if !dbtestutil.WillUsePostgres() {
+		t.Skip("test only with postgres")
+	}
+
+	// Although DERP will have connection issues until the connection is
+	// reestablished, any open connections should be maintained.
+	//
+	// Direct connections should be able to transmit packets throughout the
+	// restart without issue.
+	for _, direct := range []bool{true, false} {
+		direct := direct
+		name := "DERP"
+		if direct {
+			name = "Direct"
+		}
+
+		t.Run(name, func(t *testing.T) {
+			t.Parallel()
+
+			store, ps := dbtestutil.NewDB(t)
+			dv := coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
+				dv.DERP.Config.BlockDirect = serpent.Bool(!direct)
+			})
+			logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
+
+			// Create two restartable test servers with the same database.
+			client1, user, s1 := newRestartableTestServer(t, &coderdenttest.Options{
+				DontAddFirstUser: false,
+				DontAddLicense:   false,
+				Options: &coderdtest.Options{
+					Logger:                   ptr.Ref(logger.Named("server1")),
+					Database:                 store,
+					Pubsub:                   ps,
+					DeploymentValues:         dv,
+					IncludeProvisionerDaemon: true,
+				},
+				LicenseOptions: &coderdenttest.LicenseOptions{
+					Features: license.Features{
+						codersdk.FeatureHighAvailability: 1,
+					},
+				},
+			})
+			client2, _, s2 := newRestartableTestServer(t, &coderdenttest.Options{
+				DontAddFirstUser: true,
+				DontAddLicense:   true,
+				Options: &coderdtest.Options{
+					Logger:           ptr.Ref(logger.Named("server2")),
+					Database:         store,
+					Pubsub:           ps,
+					DeploymentValues: dv,
+				},
+			})
+			client2.SetSessionToken(client1.SessionToken())
+
+			authToken := uuid.NewString()
+			version := coderdtest.CreateTemplateVersion(t, client1, user.OrganizationID, &echo.Responses{
+				Parse:          echo.ParseComplete,
+				ProvisionPlan:  echo.PlanComplete,
+				ProvisionApply: echo.ProvisionApplyWithAgent(authToken),
+			})
+			template := coderdtest.CreateTemplate(t, client1, user.OrganizationID, version.ID)
+			_ = coderdtest.AwaitTemplateVersionJobCompleted(t, client1, version.ID)
+			workspace := coderdtest.CreateWorkspace(t, client1, template.ID)
+			_ = coderdtest.AwaitWorkspaceBuildJobCompleted(t, client1, workspace.LatestBuild.ID)
+
+			// Agent connects via the first coordinator.
+			_ = agenttest.New(t, client1.URL, authToken, func(o *agent.Options) {
+				o.Logger = logger.Named("agent1")
+			})
+			resources := coderdtest.NewWorkspaceAgentWaiter(t, client1, workspace.ID).Wait()
+
+			agentID := uuid.Nil
+			for _, r := range resources {
+				for _, a := range r.Agents {
+					agentID = a.ID
+					break
+				}
+			}
+			require.NotEqual(t, uuid.Nil, agentID)
+
+			// Client connects via the second coordinator.
+			ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
+			defer cancel()
+			workspaceClient2 := workspacesdk.New(client2)
+			conn, err := workspaceClient2.DialAgent(ctx, agentID, &workspacesdk.DialAgentOptions{
+				Logger: logger.Named("client"),
+			})
+			require.NoError(t, err)
+			defer conn.Close()
+
+			_, p2p, _, err := conn.Ping(ctx)
+			require.NoError(t, err)
+			require.Equal(t, direct, p2p, "mismatched p2p state")
+
+			// Open a TCP server and connection to it through the tunnel that
+			// should be maintained throughout the restart.
+			tcpServerAddr := tcpEchoServer(t)
+			tcpConn, err := conn.DialContext(ctx, "tcp", tcpServerAddr)
+			require.NoError(t, err)
+			defer tcpConn.Close()
+			writeReadEcho(t, ctx, tcpConn)
+
+			// Stop the first server.
+			logger.Info(ctx, "test: stopping server 1")
+			s1.Stop(t)
+
+			// Pings should fail on DERP but succeed on direct connections.
+			pingCtx, pingCancel := context.WithTimeout(ctx, 2*time.Second) //nolint:gocritic // it's going to hang and timeout for DERP, so this needs to be short
+			defer pingCancel()
+			_, p2p, _, err = conn.Ping(pingCtx)
+			if direct {
+				require.NoError(t, err)
+				require.True(t, p2p, "expected direct connection")
+			} else {
+				require.ErrorIs(t, err, context.DeadlineExceeded)
+			}
+
+			// The existing TCP connection should still be working if we're
+			// using direct connections.
+			if direct {
+				writeReadEcho(t, ctx, tcpConn)
+			}
+
+			// Start the first server again.
+			logger.Info(ctx, "test: starting server 1")
+			s1.Start(t)
+
+			// Restart the second server.
+			logger.Info(ctx, "test: stopping server 2")
+			s2.Stop(t)
+			logger.Info(ctx, "test: starting server 2")
+			s2.Start(t)
+
+			// Pings should eventually succeed on both DERP and direct
+			// connections.
+			require.True(t, conn.AwaitReachable(ctx))
+			_, p2p, _, err = conn.Ping(ctx)
+			require.NoError(t, err)
+			require.Equal(t, direct, p2p, "mismatched p2p state")
+
+			// The existing TCP connection should still be working.
+			writeReadEcho(t, ctx, tcpConn)
+		})
+	}
+}
+
 func assertEventuallyStatus(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID, status database.TailnetStatus) {
 	t.Helper()
 	assert.Eventually(t, func() bool {
@@ -978,3 +1275,53 @@ func (c *fakeCoordinator) agentNode(agentID uuid.UUID, node *agpl.Node) {
 	})
 	require.NoError(c.t, err)
 }
+
+func tcpEchoServer(t *testing.T) string {
+	var listenerWg sync.WaitGroup
+	tcpListener, err := net.Listen("tcp", "127.0.0.1:0")
+	require.NoError(t, err)
+	t.Cleanup(func() {
+		tcpListener.Close()
+		listenerWg.Wait()
+	})
+	listenerWg.Add(1)
+	go func() {
+		defer listenerWg.Done()
+		for {
+			conn, err := tcpListener.Accept()
+			if err != nil {
+				return
+			}
+			listenerWg.Add(1)
+			go func() {
+				defer listenerWg.Done()
+				defer conn.Close()
+				_, _ = io.Copy(conn, conn)
+			}()
+		}
+	}()
+
+	return tcpListener.Addr().String()
+}
+
+func writeReadEcho(t *testing.T, ctx context.Context, conn net.Conn) {
+	const msg = "hello, world"
+
+	deadline, ok := ctx.Deadline()
+	if ok {
+		conn.SetWriteDeadline(deadline)
+		defer conn.SetWriteDeadline(time.Time{})
+		conn.SetReadDeadline(deadline)
+		defer conn.SetReadDeadline(time.Time{})
+	}
+
+	// Write a message
+	_, err := conn.Write([]byte(msg))
+	require.NoError(t, err)
+
+	// Read the message back
+	buf := make([]byte, 1024)
+	n, err := conn.Read(buf)
+	require.NoError(t, err)
+	require.Equal(t, msg, string(buf[:n]))
+}
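For reference, the two echo helpers compose outside the tunnel as well. A minimal sketch, assuming it sits in the same test package so the imports above apply; TestEchoHelpers_Sketch is hypothetical and not part of this PR:

	func TestEchoHelpers_Sketch(t *testing.T) {
		t.Parallel()

		// Bound the round-trip; writeReadEcho converts the ctx deadline
		// into read/write deadlines on the conn.
		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
		defer cancel()

		// tcpEchoServer registers its own cleanup on t.
		addr := tcpEchoServer(t)
		conn, err := net.Dial("tcp", addr)
		require.NoError(t, err)
		defer conn.Close()

		writeReadEcho(t, ctx, conn)
	}

Because writeReadEcho propagates the context deadline onto the connection, a hung tunnel fails the test promptly instead of blocking in Read forever.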