diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4571ab999b..839656df99 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,7 +5,10 @@ Canonical reference for changes, improvements, and bugfixes for Boundary.
 ## Next
 
 ### Bug Fixes
-
+* sessions: Sessions and session connections have been refactored
+to better isolate transactions and prevent resource contention that caused deadlocks.
+([Issue](https://github.com/hashicorp/boundary/issues/1812),
+ [PR](https://github.com/hashicorp/boundary/pull/1919))
 * scheduler: Fix bug that causes erroneous logs when racing controllers
   attempted to run jobs
   ([Issue](https://github.com/hashicorp/boundary/issues/1903),
diff --git a/internal/cmd/commands/server/worker_shutdown_reload_test.go b/internal/cmd/commands/server/worker_shutdown_reload_test.go
index 5c0b05e018..3fa771c7d4 100644
--- a/internal/cmd/commands/server/worker_shutdown_reload_test.go
+++ b/internal/cmd/commands/server/worker_shutdown_reload_test.go
@@ -163,7 +163,7 @@ func TestServer_ShutdownWorker(t *testing.T) {
 
 	// Connection should fail, and the session should be closed on the DB.
 	sConn.TestSendRecvFail(t)
-	sess.ExpectConnectionStateOnController(ctx, t, controllerCmd.controller.SessionRepoFn, session.StatusClosed)
+	sess.ExpectConnectionStateOnController(ctx, t, controllerCmd.controller.ConnectionRepoFn, session.StatusClosed)
 
 	// We're done! Shutdown the controller, and that's it.
 	close(controllerCmd.ShutdownCh)
diff --git a/internal/db/schema/migrations/oss/postgres/0/51_connection.up.sql b/internal/db/schema/migrations/oss/postgres/0/51_connection.up.sql
index ef7a7e9bf7..dd6a5e6974 100644
--- a/internal/db/schema/migrations/oss/postgres/0/51_connection.up.sql
+++ b/internal/db/schema/migrations/oss/postgres/0/51_connection.up.sql
@@ -175,6 +175,7 @@ begin;
   after insert on session_connection
     for each row execute procedure insert_new_connection_state();
 
+-- Replaced in 27/01_disable_terminate_session.up.sql
 -- update_connection_state_on_closed_reason() is used in an update trigger on the
 -- session_connection table. it will insert a state of "closed" in
 -- session_connection_state for the closed session connection.
@@ -284,6 +285,7 @@ begin;
 create trigger insert_session_connection_state before insert on session_connection_state
   for each row execute procedure insert_session_connection_state();
 
+-- Removed in 27/01_disable_terminate_session.up.sql
 -- terminate_session_if_possible takes a session id and terminates the session
 -- if the following conditions are met:
 --   * the session is expired and all its connections are closed.
diff --git a/internal/db/schema/migrations/oss/postgres/0/69_wh_session_facts.up.sql b/internal/db/schema/migrations/oss/postgres/0/69_wh_session_facts.up.sql
index a2df68b352..a8c5b8db2c 100644
--- a/internal/db/schema/migrations/oss/postgres/0/69_wh_session_facts.up.sql
+++ b/internal/db/schema/migrations/oss/postgres/0/69_wh_session_facts.up.sql
@@ -146,6 +146,7 @@ begin;
     for each row
     execute function wh_insert_session_connection();
 
+  -- Updated in 27/02_wh_session_facts.up.sql
   -- wh_update_session_connection returns an after update trigger for the
   -- session_connection table which updates a row in
   -- wh_session_connection_accumulating_fact for the session connection.
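The changelog entry above describes the core of this PR: connection persistence moves out of `session.Repository` into a new `session.ConnectionRepository`, so the two no longer contend inside one transaction. A minimal sketch of constructing the pair, using only constructors that appear in this diff (`session.NewRepository`, `session.NewConnectionRepository`); the parameter types are assumptions based on how the tests call `db.New`:

```go
package main

import (
	"context"

	"github.com/hashicorp/boundary/internal/db"
	"github.com/hashicorp/boundary/internal/kms"
	"github.com/hashicorp/boundary/internal/session"
)

// buildRepos sketches the repository split: session state and connection
// state now live behind separate repositories, so connection-heavy work no
// longer shares (and blocks) the session repository's transactions.
func buildRepos(ctx context.Context, rw *db.Db, kmsCache *kms.Kms) (*session.Repository, *session.ConnectionRepository, error) {
	sessionRepo, err := session.NewRepository(rw, rw, kmsCache)
	if err != nil {
		return nil, nil, err
	}
	// Note the extra context argument on the connection repository constructor.
	connectionRepo, err := session.NewConnectionRepository(ctx, rw, rw, kmsCache)
	if err != nil {
		return nil, nil, err
	}
	return sessionRepo, connectionRepo, nil
}
```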
diff --git a/internal/db/schema/migrations/oss/postgres/16/05_wh_credential_dimension.up.sql b/internal/db/schema/migrations/oss/postgres/16/05_wh_credential_dimension.up.sql
index 4627b57a0f..eed53efee4 100644
--- a/internal/db/schema/migrations/oss/postgres/16/05_wh_credential_dimension.up.sql
+++ b/internal/db/schema/migrations/oss/postgres/16/05_wh_credential_dimension.up.sql
@@ -8,6 +8,7 @@ begin;
   alter table wh_session_connection_accumulating_fact
     alter column credential_group_key drop default;
 
+  -- Updated in 27/02_wh_session_facts.up.sql
   -- replaces function from 15/01_wh_rename_key_columns.up.sql
   drop trigger wh_insert_session_connection on session_connection;
   drop function wh_insert_session_connection;
diff --git a/internal/db/schema/migrations/oss/postgres/27/01_disable_terminate_session.up.sql b/internal/db/schema/migrations/oss/postgres/27/01_disable_terminate_session.up.sql
new file mode 100644
index 0000000000..c1b8db55c9
--- /dev/null
+++ b/internal/db/schema/migrations/oss/postgres/27/01_disable_terminate_session.up.sql
@@ -0,0 +1,36 @@
+begin;
+-- Replaces function from 0/51_connection.up.sql
+-- Remove call to terminate_session_if_possible
+drop trigger update_connection_state_on_closed_reason on session_connection;
+drop function update_connection_state_on_closed_reason();
+create function
+  update_connection_state_on_closed_reason()
+  returns trigger
+as $$
+  begin
+    if new.closed_reason is not null then
+      -- check to see if there's a closed state already, before inserting a new one.
+      perform from
+        session_connection_state cs
+      where
+        cs.connection_id = new.public_id and
+        cs.state = 'closed';
+      if not found then
+        insert into session_connection_state (connection_id, state)
+        values
+          (new.public_id, 'closed');
+      end if;
+    end if;
+    return new;
+  end;
+$$ language plpgsql;
+
+create trigger
+  update_connection_state_on_closed_reason
+after update of closed_reason on session_connection
+  for each row execute procedure update_connection_state_on_closed_reason();
+
+-- Remove function, defined in 0/51_connection.up.sql
+drop function terminate_session_if_possible;
+
+commit;
\ No newline at end of file
diff --git a/internal/db/schema/migrations/oss/postgres/27/02_wh_session_facts.up.sql b/internal/db/schema/migrations/oss/postgres/27/02_wh_session_facts.up.sql
new file mode 100644
index 0000000000..8f781c93de
--- /dev/null
+++ b/internal/db/schema/migrations/oss/postgres/27/02_wh_session_facts.up.sql
@@ -0,0 +1,114 @@
+begin;
+-- Updating definition from 16/05_wh_credential_dimension.up.sql
+-- Remove call to wh_rollup_connections(new.session_id) from function
+drop trigger wh_insert_session_connection on session_connection;
+drop function wh_insert_session_connection();
+create function wh_insert_session_connection()
+  returns trigger
+as $$
+declare
+new_row wh_session_connection_accumulating_fact%rowtype;
+begin
+with
+  authorized_timestamp (date_dim_key, time_dim_key, ts) as (
+    select wh_date_key(start_time), wh_time_key(start_time), start_time
+    from session_connection_state
+    where connection_id = new.public_id
+      and state = 'authorized'
+  ),
+  session_dimension (host_dim_key, user_dim_key, credential_group_dim_key) as (
+    select host_key, user_key, credential_group_key
+    from wh_session_accumulating_fact
+    where session_id = new.session_id
+  )
+insert into wh_session_connection_accumulating_fact (
+  connection_id,
+  session_id,
+  host_key,
+  user_key,
+  credential_group_key,
+  connection_authorized_date_key,
+  connection_authorized_time_key,
+  connection_authorized_time,
+  client_tcp_address,
+  client_tcp_port_number,
+  endpoint_tcp_address,
+  endpoint_tcp_port_number,
+  bytes_up,
+  bytes_down
+  )
+select new.public_id,
+       new.session_id,
+       session_dimension.host_dim_key,
+       session_dimension.user_dim_key,
+       session_dimension.credential_group_dim_key,
+       authorized_timestamp.date_dim_key,
+       authorized_timestamp.time_dim_key,
+       authorized_timestamp.ts,
+       new.client_tcp_address,
+       new.client_tcp_port,
+       new.endpoint_tcp_address,
+       new.endpoint_tcp_port,
+       new.bytes_up,
+       new.bytes_down
+from authorized_timestamp,
+     session_dimension
+  returning * into strict new_row;
+return null;
+end;
+$$ language plpgsql;
+
+create trigger wh_insert_session_connection
+  after insert on session_connection
+  for each row
+  execute function wh_insert_session_connection();
+
+-- Updating definition from 0/69_wh_session_facts.up.sql
+-- Remove call to wh_rollup_connections(new.session_id) from function
+drop trigger wh_update_session_connection on session_connection;
+drop function wh_update_session_connection;
+create function wh_update_session_connection()
+  returns trigger
+as $$
+declare
+updated_row wh_session_connection_accumulating_fact%rowtype;
+begin
+update wh_session_connection_accumulating_fact
+set client_tcp_address       = new.client_tcp_address,
+    client_tcp_port_number   = new.client_tcp_port,
+    endpoint_tcp_address     = new.endpoint_tcp_address,
+    endpoint_tcp_port_number = new.endpoint_tcp_port,
+    bytes_up                 = new.bytes_up,
+    bytes_down               = new.bytes_down
+where connection_id = new.public_id
+  returning * into strict updated_row;
+return null;
+end;
+$$ language plpgsql;
+
+create trigger wh_update_session_connection
+  after update on session_connection
+  for each row
+  execute function wh_update_session_connection();
+
+create function
+  wh_session_rollup()
+  returns trigger
+as $$
+begin
+  if new.termination_reason is not null then
+    -- Rollup will fail if no connections were made for a session
+    if exists (select from session_connection where session_id = new.public_id) then
+      perform wh_rollup_connections(new.public_id);
+    end if;
+  end if;
+return null;
+end;
+$$ language plpgsql;
+
+create trigger
+  wh_rollup_connections_on_session_termination
+after update of termination_reason on session
+  for each row execute procedure wh_session_rollup();
+
+commit;
\ No newline at end of file
diff --git a/internal/db/schema/migrations/oss/postgres_24_01_test.go b/internal/db/schema/migrations/oss/postgres_24_01_test.go
index a03a9ca26d..bb1893d4be 100644
--- a/internal/db/schema/migrations/oss/postgres_24_01_test.go
+++ b/internal/db/schema/migrations/oss/postgres_24_01_test.go
@@ -123,6 +123,7 @@ func TestMigrations_SessionState(t *testing.T) {
 	require.Equal(want, state)
 
 	sessionRepo, err := session.NewRepository(rw, rw, kmsCache)
+	connectionRepo, err := session.NewConnectionRepository(ctx, rw, rw, kmsCache)
 	require.NoError(err)
 
 	// Ensure session is cancelled
@@ -138,7 +139,7 @@ func TestMigrations_SessionState(t *testing.T) {
 	require.Equal([]string{"canceled"}, sessionTermReason)
 
 	// Ensure connection is also cancelled
-	connections, err := sessionRepo.ListConnectionsBySessionId(ctx, repoSessionId)
+	connections, err := connectionRepo.ListConnectionsBySessionId(ctx, repoSessionId)
 	require.NoError(err)
 
 	var connTermReason []string
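The migration test above now resolves a session's connections through the new repository. A small usage sketch; the element type and `ClosedReason` field are assumptions based on how this diff's tests consume the result:

```go
package main

import (
	"context"

	"github.com/hashicorp/boundary/internal/session"
)

// connectionClosedReasons mirrors the updated test: connection listing moved
// from session.Repository to session.ConnectionRepository.
func connectionClosedReasons(ctx context.Context, connRepo *session.ConnectionRepository, sessionId string) ([]string, error) {
	connections, err := connRepo.ListConnectionsBySessionId(ctx, sessionId)
	if err != nil {
		return nil, err
	}
	var reasons []string
	for _, c := range connections {
		// ClosedReason is read the same way TestStatusDeadConnection reads it
		// later in this diff.
		reasons = append(reasons, c.ClosedReason)
	}
	return reasons, nil
}
```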
diff --git a/internal/servers/controller/common/common.go b/internal/servers/controller/common/common.go
index 78a9f56eb7..4945b44d96 100644
--- a/internal/servers/controller/common/common.go
+++ b/internal/servers/controller/common/common.go
@@ -24,5 +24,6 @@ type (
 	PluginHostRepoFactory func() (*pluginhost.Repository, error)
 	HostPluginRepoFactory func() (*hostplugin.Repository, error)
 	SessionRepoFactory    func() (*session.Repository, error)
+	ConnectionRepoFactory func() (*session.ConnectionRepository, error)
 	TargetRepoFactory     func() (*target.Repository, error)
 )
diff --git a/internal/servers/controller/controller.go b/internal/servers/controller/controller.go
index 9fe2b2aa62..0268302c79 100644
--- a/internal/servers/controller/controller.go
+++ b/internal/servers/controller/controller.go
@@ -68,6 +68,7 @@ type Controller struct {
 	PasswordAuthRepoFn common.PasswordAuthRepoFactory
 	ServersRepoFn      common.ServersRepoFactory
 	SessionRepoFn      common.SessionRepoFactory
+	ConnectionRepoFn   common.ConnectionRepoFactory
 	StaticHostRepoFn   common.StaticRepoFactory
 	PluginHostRepoFn   common.PluginHostRepoFactory
 	HostPluginRepoFn   common.HostPluginRepoFactory
@@ -235,7 +236,9 @@ func New(ctx context.Context, conf *Config) (*Controller, error) {
 	c.SessionRepoFn = func() (*session.Repository, error) {
 		return session.NewRepository(dbase, dbase, c.kms)
 	}
-
+	c.ConnectionRepoFn = func() (*session.ConnectionRepository, error) {
+		return session.NewConnectionRepository(ctx, dbase, dbase, c.kms)
+	}
 	return c, nil
 }
 
@@ -290,21 +293,21 @@ func (c *Controller) registerJobs() error {
 		return err
 	}
 
-	if err := c.registerSessionCleanupJob(); err != nil {
+	if err := c.registerSessionConnectionCleanupJob(); err != nil {
 		return err
 	}
 
 	return nil
 }
 
-// registerSessionCleanupJob is a helper method to abstract
-// registering the session cleanup job specifically.
-func (c *Controller) registerSessionCleanupJob() error {
-	sessionCleanupJob, err := newSessionCleanupJob(c.SessionRepoFn, int(c.conf.StatusGracePeriodDuration.Seconds()))
+// registerSessionConnectionCleanupJob is a helper method to abstract
+// registering the session connection cleanup job specifically.
+func (c *Controller) registerSessionConnectionCleanupJob() error {
+	sessionConnectionCleanupJob, err := newSessionConnectionCleanupJob(c.ConnectionRepoFn, c.conf.StatusGracePeriodDuration)
 	if err != nil {
 		return fmt.Errorf("error creating session cleanup job: %w", err)
 	}
-	if err = c.scheduler.RegisterJob(c.baseContext, sessionCleanupJob); err != nil {
+	if err = c.scheduler.RegisterJob(c.baseContext, sessionConnectionCleanupJob); err != nil {
 		return fmt.Errorf("error registering session cleanup job: %w", err)
 	}
 
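With `common.ConnectionRepoFactory` defined and the controller wiring shown above, the factory is just a closure over the database handle and KMS. A standalone restatement of the shape `New()` uses; the `dbase` parameter type is an assumption:

```go
package main

import (
	"context"

	"github.com/hashicorp/boundary/internal/db"
	"github.com/hashicorp/boundary/internal/kms"
	"github.com/hashicorp/boundary/internal/servers/controller/common"
	"github.com/hashicorp/boundary/internal/session"
)

// makeConnectionRepoFn mirrors the closure added to New() above: every caller
// gets a fresh *session.ConnectionRepository built over the same database
// handle and KMS.
func makeConnectionRepoFn(ctx context.Context, dbase *db.Db, kmsCache *kms.Kms) common.ConnectionRepoFactory {
	return func() (*session.ConnectionRepository, error) {
		return session.NewConnectionRepository(ctx, dbase, dbase, kmsCache)
	}
}
```

Note the grace period also changed type here, from an `int` of seconds to a `time.Duration`, which the cleanup job file later in this diff consumes directly.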
diff --git a/internal/servers/controller/handlers/targets/tcp/target_service_test.go b/internal/servers/controller/handlers/targets/tcp/target_service_test.go
index 718c8d3cac..d5392d2a37 100644
--- a/internal/servers/controller/handlers/targets/tcp/target_service_test.go
+++ b/internal/servers/controller/handlers/targets/tcp/target_service_test.go
@@ -2628,6 +2628,9 @@ func TestAuthorizeSession(t *testing.T) {
 	sessionRepoFn := func() (*session.Repository, error) {
 		return session.NewRepository(rw, rw, kms)
 	}
+	connectionRepoFn := func() (*session.ConnectionRepository, error) {
+		return session.NewConnectionRepository(ctx, rw, rw, kms)
+	}
 	staticHostRepoFn := func() (*static.Repository, error) {
 		return static.NewRepository(rw, rw, kms)
 	}
@@ -2758,7 +2761,7 @@ func TestAuthorizeSession(t *testing.T) {
 	require.NoError(t, err)
 
 	// Tell our DB that there is a worker ready to serve the data
-	workerService := workers.NewWorkerServiceServer(serversRepoFn, sessionRepoFn, &sync.Map{}, kms)
+	workerService := workers.NewWorkerServiceServer(serversRepoFn, sessionRepoFn, connectionRepoFn, &sync.Map{}, kms)
 	_, err = workerService.Status(ctx, &spbs.StatusRequest{
 		Worker: &spb.Server{
 			PrivateId: "testworker",
@@ -2874,6 +2877,9 @@ func TestAuthorizeSessionTypedCredentials(t *testing.T) {
 	sessionRepoFn := func() (*session.Repository, error) {
 		return session.NewRepository(rw, rw, kms)
 	}
+	connectionRepoFn := func() (*session.ConnectionRepository, error) {
+		return session.NewConnectionRepository(ctx, rw, rw, kms)
+	}
 	staticHostRepoFn := func() (*static.Repository, error) {
 		return static.NewRepository(rw, rw, kms)
 	}
@@ -3037,7 +3043,7 @@ func TestAuthorizeSessionTypedCredentials(t *testing.T) {
 	require.NoError(t, err)
 
 	// Tell our DB that there is a worker ready to serve the data
-	workerService := workers.NewWorkerServiceServer(serversRepoFn, sessionRepoFn, &sync.Map{}, kms)
+	workerService := workers.NewWorkerServiceServer(serversRepoFn, sessionRepoFn, connectionRepoFn, &sync.Map{}, kms)
 	_, err = workerService.Status(ctx, &spbs.StatusRequest{
 		Worker: &spb.Server{
 			PrivateId: "testworker",
@@ -3126,6 +3132,9 @@ func TestAuthorizeSession_Errors(t *testing.T) {
 	sessionRepoFn := func() (*session.Repository, error) {
 		return session.NewRepository(rw, rw, kms)
 	}
+	connectionRepoFn := func() (*session.ConnectionRepository, error) {
+		return session.NewConnectionRepository(ctx, rw, rw, kms)
+	}
 	staticHostRepoFn := func() (*static.Repository, error) {
 		return static.NewRepository(rw, rw, kms)
 	}
@@ -3165,7 +3174,7 @@ func TestAuthorizeSession_Errors(t *testing.T) {
 	store := vault.TestCredentialStore(t, conn, wrapper, proj.GetPublicId(), v.Addr, tok, sec.Auth.Accessor)
 
 	workerExists := func(tar target.Target) (version uint32) {
-		workerService := workers.NewWorkerServiceServer(serversRepoFn, sessionRepoFn, &sync.Map{}, kms)
+		workerService := workers.NewWorkerServiceServer(serversRepoFn, sessionRepoFn, connectionRepoFn, &sync.Map{}, kms)
 		_, err := workerService.Status(context.Background(), &spbs.StatusRequest{
 			Worker: &spb.Server{
 				PrivateId: "testworker",
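Every call site in these tests gains the new `connectionRepoFn` argument. A sketch of the full construction-and-registration flow, following the same pattern `listeners.go` uses later in this diff; the gRPC server parameter is an assumption standing in for the controller's worker listener:

```go
package main

import (
	"sync"

	pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
	"github.com/hashicorp/boundary/internal/kms"
	"github.com/hashicorp/boundary/internal/servers/controller/common"
	"github.com/hashicorp/boundary/internal/servers/controller/handlers/workers"
	"google.golang.org/grpc"
)

// registerWorkerService shows the updated constructor shape: the
// ConnectionRepoFactory is new in this PR and sits between the session repo
// factory and the status-update map.
func registerWorkerService(
	srv *grpc.Server,
	serversRepoFn common.ServersRepoFactory,
	sessionRepoFn common.SessionRepoFactory,
	connectionRepoFn common.ConnectionRepoFactory,
	kmsCache *kms.Kms,
) {
	workerService := workers.NewWorkerServiceServer(serversRepoFn, sessionRepoFn, connectionRepoFn, new(sync.Map), kmsCache)
	pbs.RegisterServerCoordinationServiceServer(srv, workerService)
	pbs.RegisterSessionServiceServer(srv, workerService)
}
```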
diff --git a/internal/servers/controller/handlers/workers/worker_service.go b/internal/servers/controller/handlers/workers/worker_service.go
index 29e250c459..bb489a013b 100644
--- a/internal/servers/controller/handlers/workers/worker_service.go
+++ b/internal/servers/controller/handlers/workers/worker_service.go
@@ -26,22 +26,25 @@ type workerServiceServer struct {
 	pbs.UnimplementedServerCoordinationServiceServer
 	pbs.UnimplementedSessionServiceServer
 
-	serversRepoFn common.ServersRepoFactory
-	sessionRepoFn common.SessionRepoFactory
-	updateTimes   *sync.Map
-	kms           *kms.Kms
+	serversRepoFn    common.ServersRepoFactory
+	sessionRepoFn    common.SessionRepoFactory
+	connectionRepoFn common.ConnectionRepoFactory
+	updateTimes      *sync.Map
+	kms              *kms.Kms
 }
 
 func NewWorkerServiceServer(
 	serversRepoFn common.ServersRepoFactory,
 	sessionRepoFn common.SessionRepoFactory,
+	connectionRepoFn common.ConnectionRepoFactory,
 	updateTimes *sync.Map,
 	kms *kms.Kms) *workerServiceServer {
 	return &workerServiceServer{
-		serversRepoFn: serversRepoFn,
-		sessionRepoFn: sessionRepoFn,
-		updateTimes:   updateTimes,
-		kms:           kms,
+		serversRepoFn:    serversRepoFn,
+		sessionRepoFn:    sessionRepoFn,
+		connectionRepoFn: connectionRepoFn,
+		updateTimes:      updateTimes,
+		kms:              kms,
 	}
 }
 
@@ -61,6 +64,7 @@ func (ws *workerServiceServer) Status(ctx context.Context, req *pbs.StatusReques
 		return &pbs.StatusResponse{}, status.Errorf(codes.Internal, "Error acquiring repo to store worker status: %v", err)
 	}
 	sessRepo, err := ws.sessionRepoFn()
+	connectionRepo, err := ws.connectionRepoFn()
 	if err != nil {
 		event.WriteError(ctx, op, err, event.WithInfoMsg("error getting sessions repo"))
 		return &pbs.StatusResponse{}, status.Errorf(codes.Internal, "Error acquiring repo to query session status: %v", err)
@@ -75,47 +79,16 @@ func (ws *workerServiceServer) Status(ctx context.Context, req *pbs.StatusReques
 		Controllers: controllers,
 	}
 
-	var (
-		// For tracking the reported open connections.
-		reportedOpenConns []string
-		// For tracking the session IDs we've already requested
-		// cancellation for. We won't need to add connection cancel
-		// requests for these because canceling the session terminates the
-		// connections.
-		requestedSessionCancelIds []string
-	)
-
-	// This is a map of all sessions and their statuses. We keep track of
-	// this for easy lookup if we need to make change requests.
-	sessionStatuses := make(map[string]pbs.SESSIONSTATUS)
+	stateReport := make([]session.StateReport, 0, len(req.GetJobs()))
 	for _, jobStatus := range req.GetJobs() {
 		switch jobStatus.Job.GetType() {
-		// Check for session cancellation
 		case pbs.JOBTYPE_JOBTYPE_SESSION:
 			si := jobStatus.GetJob().GetSessionInfo()
 			if si == nil {
 				return nil, status.Error(codes.Internal, "Error getting session info at status time")
 			}
 
-			// Record status.
-			sessionStatuses[si.GetSessionId()] = si.Status
-
-			// Check connections before potentially bypassing the rest of the
-			// logic in the switch on si.Status.
-			sessConns := si.GetConnections()
-			for _, conn := range sessConns {
-				switch conn.Status {
-				case pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_AUTHORIZED,
-					pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CONNECTED:
-					// If it's active, report it as found. Otherwise don't
-					// report as found, so that we should attempt to close it.
-					// Note that unspecified is the default state for the enum
-					// but it's not ever explicitly set by us.
-					reportedOpenConns = append(reportedOpenConns, conn.GetConnectionId())
-				}
-			}
-
 			switch si.Status {
 			case pbs.SESSIONSTATUS_SESSIONSTATUS_CANCELING,
 				pbs.SESSIONSTATUS_SESSIONSTATUS_TERMINATED:
@@ -123,77 +96,40 @@ func (ws *workerServiceServer) Status(ctx context.Context, req *pbs.StatusReques
 				continue
 			}
 
-			sessionId := si.GetSessionId()
-			sessionInfo, _, err := sessRepo.LookupSession(ctx, sessionId)
-			if err != nil {
-				return nil, status.Errorf(codes.Internal, "Error looking up session with id %s: %v", sessionId, err)
-			}
-			if sessionInfo == nil {
-				return nil, status.Errorf(codes.Internal, "Unknown session ID %s at status time.", sessionId)
+			sr := session.StateReport{
+				SessionId:     si.GetSessionId(),
+				ConnectionIds: make([]string, 0, len(si.GetConnections())),
 			}
-			if len(sessionInfo.States) == 0 {
-				return nil, status.Error(codes.Internal, "Empty session states during lookup at status time.")
-			}
 
-			// If the session from the DB is in canceling status, and we're
-			// here, it means the job is in pending or active; cancel it. If
-			// it's in terminated status something went wrong and we're
-			// mismatched, so ensure we cancel it also.
-			currState := sessionInfo.States[0].Status
-			if currState.ProtoVal() != si.Status {
-				switch currState {
-				case session.StatusCanceling,
-					session.StatusTerminated:
-					// If we're here the job is pending or active so we do want
-					// to actually send a change request
-					ret.JobsRequests = append(ret.JobsRequests, &pbs.JobChangeRequest{
-						Job: &pbs.Job{
-							Type: pbs.JOBTYPE_JOBTYPE_SESSION,
-							JobInfo: &pbs.Job_SessionInfo{
-								SessionInfo: &pbs.SessionJobInfo{
-									SessionId: sessionId,
-									Status:    currState.ProtoVal(),
-								},
-							},
-						},
-						RequestType: pbs.CHANGETYPE_CHANGETYPE_UPDATE_STATE,
-					})
-					// Log the session ID so we don't add a duplicate change
-					// request on connection normalization.
-					requestedSessionCancelIds = append(requestedSessionCancelIds, sessionId)
+			for _, conn := range si.GetConnections() {
+				switch conn.Status {
+				case pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_AUTHORIZED,
+					pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CONNECTED:
+					sr.ConnectionIds = append(sr.ConnectionIds, conn.GetConnectionId())
 				}
 			}
+			stateReport = append(stateReport, sr)
 		}
 	}
 
-	// Normalize the current state of connections on the worker side
-	// with the data from the controller. In other words, if one of our
-	// found connections isn't supposed to be alive still, kill it.
-	//
-	// This is separate from the above session normalization and is
-	// additive to it, we don't add sessions that have already been
-	// added there as canceling sessions already closes the
-	// connections.
-	shouldCloseConnections, err := sessRepo.ShouldCloseConnectionsOnWorker(ctx, reportedOpenConns, requestedSessionCancelIds)
+	notActive, err := session.WorkerStatusReport(ctx, sessRepo, connectionRepo, req.Worker.PrivateId, stateReport)
 	if err != nil {
-		return nil, status.Errorf(codes.Internal, "Error fetching connections that should be closed: %v", err)
+		return nil, status.Errorf(codes.Internal, "Error comparing state of sessions for worker: %s: %v", req.Worker.PrivateId, err)
 	}
-
-	for sessionId, connIds := range shouldCloseConnections {
+	for _, na := range notActive {
 		var connChanges []*pbs.Connection
-		for _, connId := range connIds {
+		for _, connId := range na.ConnectionIds {
 			connChanges = append(connChanges, &pbs.Connection{
 				ConnectionId: connId,
 				Status:       session.StatusClosed.ProtoVal(),
 			})
 		}
-
 		ret.JobsRequests = append(ret.JobsRequests, &pbs.JobChangeRequest{
 			Job: &pbs.Job{
 				Type: pbs.JOBTYPE_JOBTYPE_SESSION,
 				JobInfo: &pbs.Job_SessionInfo{
 					SessionInfo: &pbs.SessionJobInfo{
-						SessionId:   sessionId,
-						Status:      sessionStatuses[sessionId],
+						SessionId:   na.SessionId,
+						Status:      na.Status.ProtoVal(),
 						Connections: connChanges,
 					},
 				},
@@ -202,15 +138,6 @@ func (ws *workerServiceServer) Status(ctx context.Context, req *pbs.StatusReques
 		})
 	}
 
-	// Run our controller-side cleanup function.
-	closedConns, err := sessRepo.CloseDeadConnectionsForWorker(ctx, req.Worker.PrivateId, reportedOpenConns)
-	if err != nil {
-		return nil, status.Errorf(codes.Internal, "Error closing dead conns for worker %s: %v", req.Worker.PrivateId, err)
-	}
-	if closedConns > 0 {
-		event.WriteSysEvent(ctx, op, "marked unclaimed connections as closed", "server_id", req.Worker.PrivateId, "count", closedConns)
-	}
-
 	return ret, nil
 }
 
@@ -386,12 +313,17 @@ func (ws *workerServiceServer) ActivateSession(ctx context.Context, req *pbs.Act
 func (ws *workerServiceServer) AuthorizeConnection(ctx context.Context, req *pbs.AuthorizeConnectionRequest) (*pbs.AuthorizeConnectionResponse, error) {
 	const op = "workers.(workerServiceServer).AuthorizeConnection"
-	sessRepo, err := ws.sessionRepoFn()
+	connectionRepo, err := ws.connectionRepoFn()
+	if err != nil {
+		return nil, status.Errorf(codes.Internal, "error getting connection repo: %v", err)
+	}
+
+	sessionRepo, err := ws.sessionRepoFn()
 	if err != nil {
 		return nil, status.Errorf(codes.Internal, "error getting session repo: %v", err)
 	}
 
-	connectionInfo, connStates, authzSummary, err := sessRepo.AuthorizeConnection(ctx, req.GetSessionId(), req.GetWorkerId())
+	connectionInfo, connStates, authzSummary, err := session.AuthorizeConnection(ctx, sessionRepo, connectionRepo, req.GetSessionId(), req.GetWorkerId())
 	if err != nil {
 		return nil, err
 	}
@@ -416,12 +348,12 @@ func (ws *workerServiceServer) ConnectConnection(ctx context.Context, req *pbs.ConnectConnectionRequest) (*pbs.ConnectConnectionResponse, error) {
 	const op = "workers.(workerServiceServer).ConnectConnection"
-	sessRepo, err := ws.sessionRepoFn()
+	connRepo, err := ws.connectionRepoFn()
 	if err != nil {
 		return nil, status.Errorf(codes.Internal, "error getting session repo: %v", err)
 	}
 
-	connectionInfo, connStates, err := sessRepo.ConnectConnection(ctx, session.ConnectWith{
+	connectionInfo, connStates, err := connRepo.ConnectConnection(ctx, session.ConnectWith{
 		ConnectionId:     req.GetConnectionId(),
 		ClientTcpAddress: req.GetClientTcpAddress(),
 		ClientTcpPort:    req.GetClientTcpPort(),
@@ -460,12 +392,17 @@ func (ws *workerServiceServer) CloseConnection(ctx context.Context, req *pbs.Clo
 			ClosedReason: session.ClosedReason(v.GetReason()),
 		})
 	}
+	connRepo, err := ws.connectionRepoFn()
+	if err != nil {
+		return nil, status.Errorf(codes.Internal, "error getting connection repo: %v", err)
+	}
+
 	sessRepo, err := ws.sessionRepoFn()
 	if err != nil {
 		return nil, status.Errorf(codes.Internal, "error getting session repo: %v", err)
 	}
 
-	closeInfos, err := sessRepo.CloseConnections(ctx, closeWiths)
+	closeInfos, err := session.CloseConnections(ctx, sessRepo, connRepo, closeWiths)
 	if err != nil {
 		return nil, err
 	}
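The rewritten `Status` handler condenses the worker's reported jobs into `session.StateReport` values and delegates the database comparison to `session.WorkerStatusReport`, replacing the inlined per-session lookups that previously contended with connection transactions. A restatement of that flow as a standalone sketch; the return type of `WorkerStatusReport` is assumed to be `[]session.StateReport`, matching how the handler reads `SessionId`, `Status`, and `ConnectionIds` off each element:

```go
package main

import (
	"context"

	pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
	"github.com/hashicorp/boundary/internal/session"
)

// reportWorkerState mirrors the loop in the new Status handler above.
func reportWorkerState(ctx context.Context, sessRepo *session.Repository, connRepo *session.ConnectionRepository, workerId string, jobs []*pbs.JobStatus) ([]session.StateReport, error) {
	reports := make([]session.StateReport, 0, len(jobs))
	for _, js := range jobs {
		if js.GetJob().GetType() != pbs.JOBTYPE_JOBTYPE_SESSION {
			continue
		}
		si := js.GetJob().GetSessionInfo()
		if si == nil {
			continue
		}
		sr := session.StateReport{
			SessionId:     si.GetSessionId(),
			ConnectionIds: make([]string, 0, len(si.GetConnections())),
		}
		for _, conn := range si.GetConnections() {
			switch conn.Status {
			case pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_AUTHORIZED,
				pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CONNECTED:
				// Only connections the worker still believes are live count as
				// reported open; anything else is left for normalization.
				sr.ConnectionIds = append(sr.ConnectionIds, conn.GetConnectionId())
			}
		}
		reports = append(reports, sr)
	}
	// Returns the sessions/connections no longer active on the controller
	// side, which the handler turns into close requests for the worker.
	return session.WorkerStatusReport(ctx, sessRepo, connRepo, workerId, reports)
}
```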
diff --git a/internal/servers/controller/handlers/workers/worker_service_status_test.go b/internal/servers/controller/handlers/workers/worker_service_status_test.go
new file mode 100644
index 0000000000..c1efb3ba79
--- /dev/null
+++ b/internal/servers/controller/handlers/workers/worker_service_status_test.go
@@ -0,0 +1,509 @@
+package workers_test
+
+import (
+	"context"
+	"sync"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+	"github.com/google/go-cmp/cmp/cmpopts"
+	"github.com/hashicorp/boundary/internal/authtoken"
+	"github.com/hashicorp/boundary/internal/db"
+	pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
+	"github.com/hashicorp/boundary/internal/host/static"
+	"github.com/hashicorp/boundary/internal/iam"
+	"github.com/hashicorp/boundary/internal/kms"
+	"github.com/hashicorp/boundary/internal/servers"
+	"github.com/hashicorp/boundary/internal/servers/controller/handlers/workers"
+	"github.com/hashicorp/boundary/internal/session"
+	"github.com/hashicorp/boundary/internal/target"
+	"github.com/hashicorp/boundary/internal/target/tcp"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestStatus(t *testing.T) {
+	ctx := context.Background()
+	conn, _ := db.TestSetup(t, "postgres")
+	rw := db.New(conn)
+	wrapper := db.TestWrapper(t)
+	kms := kms.TestKms(t, conn, wrapper)
+	org, prj := iam.TestScopes(t, iam.TestRepo(t, conn, wrapper))
+
+	serverRepo, _ := servers.NewRepository(rw, rw, kms)
+	serverRepo.UpsertServer(ctx, &servers.Server{
+		PrivateId: "test_controller1",
+		Type:      "controller",
+		Address:   "127.0.0.1",
+	})
+	serverRepo.UpsertServer(ctx, &servers.Server{
+		PrivateId: "test_worker1",
+		Type:      "worker",
+		Address:   "127.0.0.1",
+	})
+
+	serversRepoFn := func() (*servers.Repository, error) {
+		return serverRepo, nil
+	}
+	sessionRepoFn := func() (*session.Repository, error) {
+		return session.NewRepository(rw, rw, kms)
+	}
+	connRepoFn := func() (*session.ConnectionRepository, error) {
+		return session.NewConnectionRepository(ctx, rw, rw, kms)
+	}
+
+	repo, err := sessionRepoFn()
+	require.NoError(t, err)
+	connRepo, err := connRepoFn()
+	require.NoError(t, err)
+
+	at := authtoken.TestAuthToken(t, conn, kms, org.GetPublicId())
+	uId := at.GetIamUserId()
+	hc := static.TestCatalogs(t, conn, prj.GetPublicId(), 1)[0]
+	hs := static.TestSets(t, conn, hc.GetPublicId(), 1)[0]
+	h := static.TestHosts(t, conn, hc.GetPublicId(), 1)[0]
+	static.TestSetMembers(t, conn, hs.GetPublicId(), []*static.Host{h})
+	tar := tcp.TestTarget(
+		ctx,
+		t, conn, prj.GetPublicId(), "test",
+		target.WithHostSources([]string{hs.GetPublicId()}),
+		target.WithSessionConnectionLimit(-1),
+	)
+
+	worker1 := session.TestWorker(t, conn, wrapper)
+
+	sess := session.TestSession(t, conn, wrapper, session.ComposedOf{
+		UserId:      uId,
+		HostId:      h.GetPublicId(),
+		TargetId:    tar.GetPublicId(),
+		HostSetId:   hs.GetPublicId(),
+		AuthTokenId: at.GetPublicId(),
+		ScopeId:     prj.GetPublicId(),
"tcp://127.0.0.1:22", + ConnectionLimit: 10, + }) + tofu := session.TestTofu(t) + sess, _, err = repo.ActivateSession(ctx, sess.PublicId, sess.Version, worker1.PrivateId, worker1.Type, tofu) + require.NoError(t, err) + require.NoError(t, err) + + s := workers.NewWorkerServiceServer(serversRepoFn, sessionRepoFn, connRepoFn, new(sync.Map), kms) + require.NotNil(t, s) + + connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker1.PrivateId) + require.NoError(t, err) + + cases := []struct { + name string + wantErr bool + wantErrMsg string + req *pbs.StatusRequest + want *pbs.StatusResponse + }{ + { + name: "No Sessions", + wantErr: false, + req: &pbs.StatusRequest{ + Worker: worker1, + }, + want: &pbs.StatusResponse{ + Controllers: []*servers.Server{ + { + PrivateId: "test_controller1", + Type: "controller", + Address: "127.0.0.1", + }, + }, + }, + }, + { + name: "Still Active", + wantErr: false, + req: &pbs.StatusRequest{ + Worker: worker1, + Jobs: []*pbs.JobStatus{ + { + Job: &pbs.Job{ + Type: pbs.JOBTYPE_JOBTYPE_SESSION, + JobInfo: &pbs.Job_SessionInfo{ + SessionInfo: &pbs.SessionJobInfo{ + SessionId: sess.PublicId, + Status: pbs.SESSIONSTATUS_SESSIONSTATUS_ACTIVE, + Connections: []*pbs.Connection{ + { + ConnectionId: connection.PublicId, + Status: pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CONNECTED, + }, + }, + }, + }, + }, + }, + }, + }, + want: &pbs.StatusResponse{ + Controllers: []*servers.Server{ + { + PrivateId: "test_controller1", + Type: "controller", + Address: "127.0.0.1", + }, + }, + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + got, err := s.Status(ctx, tc.req) + if tc.wantErr { + require.Error(err) + assert.Nil(got) + assert.Equal(tc.wantErrMsg, err.Error()) + return + } + assert.Empty( + cmp.Diff( + tc.want, + got, + cmpopts.IgnoreUnexported( + pbs.StatusResponse{}, + servers.Server{}, + pbs.JobChangeRequest{}, + pbs.Job{}, + pbs.Job_SessionInfo{}, + pbs.SessionJobInfo{}, + pbs.Connection{}, + ), + cmpopts.IgnoreFields(servers.Server{}, "CreateTime", "UpdateTime"), + ), + ) + }) + } +} + +func TestStatusSessionClosed(t *testing.T) { + ctx := context.Background() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + kms := kms.TestKms(t, conn, wrapper) + org, prj := iam.TestScopes(t, iam.TestRepo(t, conn, wrapper)) + + serverRepo, _ := servers.NewRepository(rw, rw, kms) + serverRepo.UpsertServer(ctx, &servers.Server{ + PrivateId: "test_controller1", + Type: "controller", + Address: "127.0.0.1", + }) + serverRepo.UpsertServer(ctx, &servers.Server{ + PrivateId: "test_worker1", + Type: "worker", + Address: "127.0.0.1", + }) + + serversRepoFn := func() (*servers.Repository, error) { + return serverRepo, nil + } + sessionRepoFn := func() (*session.Repository, error) { + return session.NewRepository(rw, rw, kms) + } + connRepoFn := func() (*session.ConnectionRepository, error) { + return session.NewConnectionRepository(ctx, rw, rw, kms) + } + + repo, err := sessionRepoFn() + require.NoError(t, err) + connRepo, err := connRepoFn() + require.NoError(t, err) + + at := authtoken.TestAuthToken(t, conn, kms, org.GetPublicId()) + uId := at.GetIamUserId() + hc := static.TestCatalogs(t, conn, prj.GetPublicId(), 1)[0] + hs := static.TestSets(t, conn, hc.GetPublicId(), 1)[0] + h := static.TestHosts(t, conn, hc.GetPublicId(), 1)[0] + static.TestSetMembers(t, conn, hs.GetPublicId(), []*static.Host{h}) + tar := tcp.TestTarget( + ctx, + t, conn, 
prj.GetPublicId(), "test", + target.WithHostSources([]string{hs.GetPublicId()}), + target.WithSessionConnectionLimit(-1), + ) + + worker1 := session.TestWorker(t, conn, wrapper) + + sess := session.TestSession(t, conn, wrapper, session.ComposedOf{ + UserId: uId, + HostId: h.GetPublicId(), + TargetId: tar.GetPublicId(), + HostSetId: hs.GetPublicId(), + AuthTokenId: at.GetPublicId(), + ScopeId: prj.GetPublicId(), + Endpoint: "tcp://127.0.0.1:22", + ConnectionLimit: 10, + }) + tofu := session.TestTofu(t) + sess, _, err = repo.ActivateSession(ctx, sess.PublicId, sess.Version, worker1.PrivateId, worker1.Type, tofu) + require.NoError(t, err) + sess2 := session.TestSession(t, conn, wrapper, session.ComposedOf{ + UserId: uId, + HostId: h.GetPublicId(), + TargetId: tar.GetPublicId(), + HostSetId: hs.GetPublicId(), + AuthTokenId: at.GetPublicId(), + ScopeId: prj.GetPublicId(), + Endpoint: "tcp://127.0.0.1:22", + ConnectionLimit: 10, + }) + tofu2 := session.TestTofu(t) + sess2, _, err = repo.ActivateSession(ctx, sess2.PublicId, sess2.Version, worker1.PrivateId, worker1.Type, tofu2) + require.NoError(t, err) + + s := workers.NewWorkerServiceServer(serversRepoFn, sessionRepoFn, connRepoFn, new(sync.Map), kms) + require.NotNil(t, s) + + connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker1.PrivateId) + require.NoError(t, err) + + cases := []struct { + name string + wantErr bool + wantErrMsg string + setupFn func(t *testing.T) + req *pbs.StatusRequest + want *pbs.StatusResponse + }{ + { + name: "Connection Canceled", + wantErr: false, + setupFn: func(t *testing.T) { + _, err := repo.CancelSession(ctx, sess2.PublicId, sess.Version) + require.NoError(t, err) + }, + req: &pbs.StatusRequest{ + Worker: worker1, + Jobs: []*pbs.JobStatus{ + { + Job: &pbs.Job{ + Type: pbs.JOBTYPE_JOBTYPE_SESSION, + JobInfo: &pbs.Job_SessionInfo{ + SessionInfo: &pbs.SessionJobInfo{ + SessionId: sess2.PublicId, + Status: pbs.SESSIONSTATUS_SESSIONSTATUS_ACTIVE, + Connections: []*pbs.Connection{ + { + ConnectionId: connection.PublicId, + Status: pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CONNECTED, + }, + }, + }, + }, + }, + }, + }, + }, + want: &pbs.StatusResponse{ + Controllers: []*servers.Server{ + { + PrivateId: "test_controller1", + Type: "controller", + Address: "127.0.0.1", + }, + }, + JobsRequests: []*pbs.JobChangeRequest{ + { + Job: &pbs.Job{ + Type: pbs.JOBTYPE_JOBTYPE_SESSION, + JobInfo: &pbs.Job_SessionInfo{ + SessionInfo: &pbs.SessionJobInfo{ + SessionId: sess2.PublicId, + Status: pbs.SESSIONSTATUS_SESSIONSTATUS_CANCELING, + }, + }, + }, + RequestType: pbs.CHANGETYPE_CHANGETYPE_UPDATE_STATE, + }, + }, + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + if tc.setupFn != nil { + tc.setupFn(t) + } + got, err := s.Status(ctx, tc.req) + if tc.wantErr { + require.Error(err) + assert.Nil(got) + assert.Equal(tc.wantErrMsg, err.Error()) + return + } + assert.Empty( + cmp.Diff( + tc.want, + got, + cmpopts.IgnoreUnexported( + pbs.StatusResponse{}, + servers.Server{}, + pbs.JobChangeRequest{}, + pbs.Job{}, + pbs.Job_SessionInfo{}, + pbs.SessionJobInfo{}, + pbs.Connection{}, + ), + cmpopts.IgnoreFields(servers.Server{}, "CreateTime", "UpdateTime"), + ), + ) + }) + } +} + +func TestStatusDeadConnection(t *testing.T) { + ctx := context.Background() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + kms := kms.TestKms(t, conn, wrapper) + org, prj := iam.TestScopes(t, iam.TestRepo(t, conn, 
+	org, prj := iam.TestScopes(t, iam.TestRepo(t, conn, wrapper))
+
+	serverRepo, _ := servers.NewRepository(rw, rw, kms)
+	serverRepo.UpsertServer(ctx, &servers.Server{
+		PrivateId: "test_controller1",
+		Type:      "controller",
+		Address:   "127.0.0.1",
+	})
+	serverRepo.UpsertServer(ctx, &servers.Server{
+		PrivateId: "test_worker1",
+		Type:      "worker",
+		Address:   "127.0.0.1",
+	})
+
+	serversRepoFn := func() (*servers.Repository, error) {
+		return serverRepo, nil
+	}
+	sessionRepoFn := func() (*session.Repository, error) {
+		return session.NewRepository(rw, rw, kms)
+	}
+	connRepoFn := func() (*session.ConnectionRepository, error) {
+		return session.NewConnectionRepository(ctx, rw, rw, kms, session.WithWorkerStateDelay(0))
+	}
+
+	repo, err := sessionRepoFn()
+	require.NoError(t, err)
+	connRepo, err := connRepoFn()
+	require.NoError(t, err)
+
+	at := authtoken.TestAuthToken(t, conn, kms, org.GetPublicId())
+	uId := at.GetIamUserId()
+	hc := static.TestCatalogs(t, conn, prj.GetPublicId(), 1)[0]
+	hs := static.TestSets(t, conn, hc.GetPublicId(), 1)[0]
+	h := static.TestHosts(t, conn, hc.GetPublicId(), 1)[0]
+	static.TestSetMembers(t, conn, hs.GetPublicId(), []*static.Host{h})
+	tar := tcp.TestTarget(
+		ctx,
+		t, conn, prj.GetPublicId(), "test",
+		target.WithHostSources([]string{hs.GetPublicId()}),
+		target.WithSessionConnectionLimit(-1),
+	)
+
+	worker1 := session.TestWorker(t, conn, wrapper)
+
+	sess := session.TestSession(t, conn, wrapper, session.ComposedOf{
+		UserId:      uId,
+		HostId:      h.GetPublicId(),
+		TargetId:    tar.GetPublicId(),
+		HostSetId:   hs.GetPublicId(),
+		AuthTokenId: at.GetPublicId(),
+		ScopeId:     prj.GetPublicId(),
+		Endpoint:        "tcp://127.0.0.1:22",
+		ConnectionLimit: 10,
+	})
+	tofu := session.TestTofu(t)
+	sess, _, err = repo.ActivateSession(ctx, sess.PublicId, sess.Version, worker1.PrivateId, worker1.Type, tofu)
+	require.NoError(t, err)
+	sess2 := session.TestSession(t, conn, wrapper, session.ComposedOf{
+		UserId:      uId,
+		HostId:      h.GetPublicId(),
+		TargetId:    tar.GetPublicId(),
+		HostSetId:   hs.GetPublicId(),
+		AuthTokenId: at.GetPublicId(),
+		ScopeId:     prj.GetPublicId(),
+		Endpoint:        "tcp://127.0.0.1:22",
+		ConnectionLimit: 10,
+	})
+	tofu2 := session.TestTofu(t)
+	sess2, _, err = repo.ActivateSession(ctx, sess2.PublicId, sess2.Version, worker1.PrivateId, worker1.Type, tofu2)
+	require.NoError(t, err)
+
+	s := workers.NewWorkerServiceServer(serversRepoFn, sessionRepoFn, connRepoFn, new(sync.Map), kms)
+	require.NotNil(t, s)
+
+	connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker1.PrivateId)
+	require.NoError(t, err)
+	deadConn, _, err := connRepo.AuthorizeConnection(ctx, sess2.PublicId, worker1.PrivateId)
+	require.NoError(t, err)
+	require.NotEqual(t, deadConn.PublicId, connection.PublicId)
+
+	req := &pbs.StatusRequest{
+		Worker: worker1,
+		Jobs: []*pbs.JobStatus{
+			{
+				Job: &pbs.Job{
+					Type: pbs.JOBTYPE_JOBTYPE_SESSION,
+					JobInfo: &pbs.Job_SessionInfo{
+						SessionInfo: &pbs.SessionJobInfo{
+							SessionId: sess.PublicId,
+							Status:    pbs.SESSIONSTATUS_SESSIONSTATUS_ACTIVE,
+							Connections: []*pbs.Connection{
+								{
+									ConnectionId: connection.PublicId,
+									Status:       pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CONNECTED,
+								},
+							},
+						},
+					},
+				},
+			},
+		},
+	}
+	want := &pbs.StatusResponse{
+		Controllers: []*servers.Server{
+			{
+				PrivateId: "test_controller1",
+				Type:      "controller",
+				Address:   "127.0.0.1",
+			},
+		},
+	}
+
+	got, err := s.Status(ctx, req)
+	assert.Empty(t,
+		cmp.Diff(
+			want,
+			got,
+			cmpopts.IgnoreUnexported(
+				pbs.StatusResponse{},
+				servers.Server{},
+				pbs.JobChangeRequest{},
+				pbs.Job{},
+				pbs.Job_SessionInfo{},
+				pbs.SessionJobInfo{},
+				pbs.Connection{},
+			),
+			cmpopts.IgnoreFields(servers.Server{}, "CreateTime", "UpdateTime"),
+		),
+	)
+
+	gotConn, states, err := connRepo.LookupConnection(ctx, deadConn.PublicId)
+	require.NoError(t, err)
+	assert.Equal(t, session.ConnectionSystemError, session.ClosedReason(gotConn.ClosedReason))
+	assert.Equal(t, 2, len(states))
+	assert.Nil(t, states[0].EndTime)
+	assert.Equal(t, session.StatusClosed, states[0].Status)
+}
diff --git a/internal/servers/controller/handlers/workers/worker_service_test.go b/internal/servers/controller/handlers/workers/worker_service_test.go
index 50cd2e1a52..025ea901c7 100644
--- a/internal/servers/controller/handlers/workers/worker_service_test.go
+++ b/internal/servers/controller/handlers/workers/worker_service_test.go
@@ -37,6 +37,9 @@ func TestLookupSession(t *testing.T) {
 	sessionRepoFn := func() (*session.Repository, error) {
 		return session.NewRepository(rw, rw, kms)
 	}
+	connectionRepoFn := func() (*session.ConnectionRepository, error) {
+		return session.NewConnectionRepository(ctx, rw, rw, kms)
+	}
 
 	at := authtoken.TestAuthToken(t, conn, kms, org.GetPublicId())
 	uId := at.GetIamUserId()
@@ -97,7 +100,7 @@ func TestLookupSession(t *testing.T) {
 	err = repo.AddSessionCredentials(ctx, egressSess.ScopeId, egressSess.GetPublicId(), workerCreds)
 	require.NoError(t, err)
 
-	s := workers.NewWorkerServiceServer(serversRepoFn, sessionRepoFn, new(sync.Map), kms)
+	s := workers.NewWorkerServiceServer(serversRepoFn, sessionRepoFn, connectionRepoFn, new(sync.Map), kms)
 	require.NotNil(t, s)
 
 	cases := []struct {
diff --git a/internal/servers/controller/listeners.go b/internal/servers/controller/listeners.go
index e72e370297..2079018ded 100644
--- a/internal/servers/controller/listeners.go
+++ b/internal/servers/controller/listeners.go
@@ -123,7 +123,8 @@ func (c *Controller) startListeners(ctx context.Context) error {
 			),
 		),
 	)
-	workerService := workers.NewWorkerServiceServer(c.ServersRepoFn, c.SessionRepoFn, c.workerStatusUpdateTimes, c.kms)
+	workerService := workers.NewWorkerServiceServer(c.ServersRepoFn, c.SessionRepoFn, c.ConnectionRepoFn,
+		c.workerStatusUpdateTimes, c.kms)
 	pbs.RegisterServerCoordinationServiceServer(workerServer, workerService)
 	pbs.RegisterSessionServiceServer(workerServer, workerService)
 
diff --git a/internal/servers/controller/multi_test.go b/internal/servers/controller/multi_test.go
index bbe874a9eb..c2d7cec213 100644
--- a/internal/servers/controller/multi_test.go
+++ b/internal/servers/controller/multi_test.go
@@ -3,7 +3,6 @@ package controller_test
 import (
 	"encoding/json"
 	"testing"
-	"time"
 
 	"github.com/hashicorp/boundary/api/authmethods"
 	"github.com/hashicorp/boundary/api/authtokens"
@@ -38,7 +37,6 @@ func TestAuthenticationMulti(t *testing.T) {
 	require.NoError(json.Unmarshal(token1Result.GetRawAttributes(), token1))
 	require.NotNil(token1)
 
-	time.Sleep(5 * time.Second)
 	auth = authmethods.NewClient(c2.Client())
 	token2Result, err := auth.Authenticate(c2.Context(), c2.Server().DevPasswordAuthMethodId, "login", map[string]interface{}{"login_name": c2.Server().DevLoginName, "password": c2.Server().DevPassword})
 	require.Nil(err)
diff --git a/internal/servers/controller/session_cleanup_job.go b/internal/servers/controller/session_cleanup_job.go
index a74d92ce34..1d63e22e8c 100644
--- a/internal/servers/controller/session_cleanup_job.go
+++ b/internal/servers/controller/session_cleanup_job.go
@@ -13,7 +13,7 @@ import (
 	"github.com/hashicorp/boundary/internal/session"
 )
 
-// sessionCleanupJob defines a periodic job that monitors workers for
+// sessionConnectionCleanupJob defines a periodic job that monitors workers for
 // loss of connection and terminates connections on workers that have
 // not sent a heartbeat in a significant period of time.
 //
@@ -22,58 +22,58 @@ import (
 // worker, or the event of a synchronization issue between the two,
 // the controller will win out and order that the connections be
 // closed on the worker.
-type sessionCleanupJob struct {
-	sessionRepoFn common.SessionRepoFactory
+type sessionConnectionCleanupJob struct {
+	connectionRepoFn common.ConnectionRepoFactory
 
 	// The amount of time to give disconnected workers before marking
 	// their connections as closed.
-	gracePeriod int
+	gracePeriod time.Duration
 
 	// The total number of connections closed in the last run.
 	totalClosed int
 }
 
-// newSessionCleanupJob instantiates the session cleanup job.
-func newSessionCleanupJob(
-	sessionRepoFn common.SessionRepoFactory,
-	gracePeriod int,
-) (*sessionCleanupJob, error) {
-	const op = "controller.newNewSessionCleanupJob"
+// newSessionConnectionCleanupJob instantiates the session connection cleanup job.
+func newSessionConnectionCleanupJob(
+	connectionRepoFn common.ConnectionRepoFactory,
+	gracePeriod time.Duration,
+) (*sessionConnectionCleanupJob, error) {
+	const op = "controller.newSessionConnectionCleanupJob"
 	switch {
-	case sessionRepoFn == nil:
-		return nil, errors.NewDeprecated(errors.InvalidParameter, op, "missing sessionRepoFn")
+	case connectionRepoFn == nil:
+		return nil, errors.NewDeprecated(errors.InvalidParameter, op, "missing connectionRepoFn")
 	case gracePeriod < session.DeadWorkerConnCloseMinGrace:
 		return nil, errors.NewDeprecated(
-			errors.InvalidParameter, op, fmt.Sprintf("invalid gracePeriod, must be greater than %d", session.DeadWorkerConnCloseMinGrace))
+			errors.InvalidParameter, op, fmt.Sprintf("invalid gracePeriod, must be greater than %s", session.DeadWorkerConnCloseMinGrace))
 	}
 
-	return &sessionCleanupJob{
-		sessionRepoFn: sessionRepoFn,
-		gracePeriod:   gracePeriod,
+	return &sessionConnectionCleanupJob{
+		connectionRepoFn: connectionRepoFn,
+		gracePeriod:      gracePeriod,
 	}, nil
 }
 
 // Name returns a short, unique name for the job.
-func (j *sessionCleanupJob) Name() string { return "session_cleanup" }
+func (j *sessionConnectionCleanupJob) Name() string { return "session_cleanup" }
 
 // Description returns the description for the job.
-func (j *sessionCleanupJob) Description() string {
+func (j *sessionConnectionCleanupJob) Description() string {
 	return "Clean up session connections from disconnected workers"
 }
 
 // NextRunIn returns the next run time after a job is completed.
 //
-// The next run time is defined for sessionCleanupJob as one second.
+// The next run time is defined for sessionConnectionCleanupJob as one second.
 // This is because the job should run continuously to terminate
 // connections as soon as a worker has not reported in for a long
 // enough time. Only one job will ever run at once, so there is no
 // reason why it cannot run again immediately.
-func (j *sessionCleanupJob) NextRunIn(_ context.Context) (time.Duration, error) {
+func (j *sessionConnectionCleanupJob) NextRunIn(_ context.Context) (time.Duration, error) {
 	return time.Second, nil
 }
 
 // Status returns the status of the running job.
-func (j *sessionCleanupJob) Status() scheduler.JobStatus {
+func (j *sessionConnectionCleanupJob) Status() scheduler.JobStatus {
 	return scheduler.JobStatus{
 		Completed: j.totalClosed,
 		Total:     j.totalClosed,
@@ -81,18 +81,18 @@ func (j *sessionCleanupJob) Status() scheduler.JobStatus {
 }
 
 // Run executes the job.
-func (j *sessionCleanupJob) Run(ctx context.Context) error {
-	const op = "controller.(sessionCleanupJob).Run"
+func (j *sessionConnectionCleanupJob) Run(ctx context.Context) error {
+	const op = "controller.(sessionConnectionCleanupJob).Run"
 	j.totalClosed = 0
 
 	// Load repos.
-	sessionRepo, err := j.sessionRepoFn()
+	connectionRepo, err := j.connectionRepoFn()
 	if err != nil {
 		return errors.Wrap(ctx, err, op, errors.WithMsg("error getting session repo"))
 	}
 
 	// Run the atomic dead worker cleanup job.
-	results, err := sessionRepo.CloseConnectionsForDeadWorkers(ctx, j.gracePeriod)
+	results, err := connectionRepo.CloseConnectionsForDeadWorkers(ctx, j.gracePeriod)
 	if err != nil {
 		return errors.Wrap(ctx, err, op)
 	}
@@ -102,7 +102,7 @@ func (j *sessionCleanupJob) Run(ctx context.Context) error {
 			event.WithInfo(
 				"private_id", result.ServerId,
 				"update_time", result.LastUpdateTime,
-				"grace_period_seconds", j.gracePeriod,
+				"grace_period_seconds", j.gracePeriod.Seconds(),
 				"number_connections_closed", result.NumberConnectionsClosed,
 			))
 
diff --git a/internal/servers/controller/session_cleanup_job_test.go b/internal/servers/controller/session_cleanup_job_test.go
index cc679678ab..094a0602d3 100644
--- a/internal/servers/controller/session_cleanup_job_test.go
+++ b/internal/servers/controller/session_cleanup_job_test.go
@@ -18,13 +18,17 @@ import (
 )
 
 // assert the interface
-var _ = scheduler.Job(new(sessionCleanupJob))
+var _ = scheduler.Job(new(sessionConnectionCleanupJob))
 
 // This test has been largely adapted from
 // TestRepository_CloseDeadConnectionsOnWorker in
 // internal/session/repository_connection_test.go.
-func TestSessionCleanupJob(t *testing.T) {
+func TestSessionConnectionCleanupJob(t *testing.T) {
 	t.Parallel()
+	ctx := context.Background()
+
+	const gracePeriod = 1 * time.Second
+
 	require, assert := require.New(t), assert.New(t)
 	conn, _ := db.TestSetup(t, "postgres")
 	rw := db.New(conn)
@@ -34,8 +38,9 @@ func TestSessionCleanupJob(t *testing.T) {
 	serversRepo, err := servers.NewRepository(rw, rw, kms)
 	require.NoError(err)
 	sessionRepo, err := session.NewRepository(rw, rw, kms)
+	connectionRepo, err := session.NewConnectionRepository(ctx, rw, rw, kms, session.WithDeadWorkerConnCloseMinGrace(gracePeriod))
 	require.NoError(err)
-	ctx := context.Background()
+
 	numConns := 12
 
 	// Create two "workers". One will remain untouched while the other "goes
@@ -54,7 +59,7 @@ func TestSessionCleanupJob(t *testing.T) {
 		sess := session.TestDefaultSession(t, conn, wrapper, iamRepo, session.WithServerId(serverId), session.WithDbOpts(db.WithSkipVetForWrite(true)))
 		sess, _, err = sessionRepo.ActivateSession(ctx, sess.GetPublicId(), sess.Version, serverId, "worker", []byte("foo"))
 		require.NoError(err)
-		c, cs, _, err := sessionRepo.AuthorizeConnection(ctx, sess.GetPublicId(), serverId)
+		c, cs, _, err := session.AuthorizeConnection(ctx, sessionRepo, connectionRepo, sess.GetPublicId(), serverId)
 		require.NoError(err)
 		require.Len(cs, 1)
 		require.Equal(session.StatusAuthorized, cs[0].Status)
@@ -70,7 +75,7 @@ func TestSessionCleanupJob(t *testing.T) {
 	// This is just to ensure we have a spread when we test it out.
 	for i, connId := range connIds {
 		if i%2 == 0 {
-			_, cs, err := sessionRepo.ConnectConnection(ctx, session.ConnectWith{
+			_, cs, err := connectionRepo.ConnectConnection(ctx, session.ConnectWith{
 				ConnectionId:     connId,
 				ClientTcpAddress: "127.0.0.1",
 				ClientTcpPort:    22,
@@ -95,14 +100,15 @@ func TestSessionCleanupJob(t *testing.T) {
 	}
 
 	// Create the job.
-	job, err := newSessionCleanupJob(
-		func() (*session.Repository, error) { return sessionRepo, nil },
+	job, err := newSessionConnectionCleanupJob(
+		func() (*session.ConnectionRepository, error) { return connectionRepo, nil },
 		session.DeadWorkerConnCloseMinGrace,
 	)
+	job.gracePeriod = gracePeriod // bypass factory assert so we don't have to wait so long
 	require.NoError(err)
 
 	// sleep the status grace period.
-	time.Sleep(time.Second * time.Duration(session.DeadWorkerConnCloseMinGrace))
+	time.Sleep(gracePeriod)
 
 	// Push an upsert to the first worker so that its status has been
 	// updated.
@@ -119,7 +125,7 @@ func TestSessionCleanupJob(t *testing.T) {
 	require.True(ok)
 	require.Len(connIds, 6)
 	for _, connId := range connIds {
-		_, states, err := sessionRepo.LookupConnection(ctx, connId, nil)
+		_, states, err := connectionRepo.LookupConnection(ctx, connId, nil)
 		require.NoError(err)
 		var foundClosed bool
 		for _, state := range states {
@@ -138,27 +144,27 @@ func TestSessionCleanupJob(t *testing.T) {
 	assertConnections(worker1.PrivateId, false)
 }
 
-func TestSessionCleanupJobNewJobErr(t *testing.T) {
+func TestSessionConnectionCleanupJobNewJobErr(t *testing.T) {
 	t.Parallel()
 	ctx := context.TODO()
-	const op = "controller.newNewSessionCleanupJob"
+	const op = "controller.newSessionConnectionCleanupJob"
 	require := require.New(t)
-	job, err := newSessionCleanupJob(nil, 0)
+	job, err := newSessionConnectionCleanupJob(nil, 0)
 	require.Equal(err, errors.E(
 		ctx,
 		errors.WithCode(errors.InvalidParameter),
 		errors.WithOp(op),
-		errors.WithMsg("missing sessionRepoFn"),
+		errors.WithMsg("missing connectionRepoFn"),
 	))
 	require.Nil(job)
 
-	job, err = newSessionCleanupJob(func() (*session.Repository, error) { return nil, nil }, 0)
+	job, err = newSessionConnectionCleanupJob(func() (*session.ConnectionRepository, error) { return nil, nil }, 0)
 	require.Equal(err, errors.E(
 		ctx,
 		errors.WithCode(errors.InvalidParameter),
 		errors.WithOp(op),
-		errors.WithMsg(fmt.Sprintf("invalid gracePeriod, must be greater than %d", session.DeadWorkerConnCloseMinGrace)),
+		errors.WithMsg(fmt.Sprintf("invalid gracePeriod, must be greater than %s", session.DeadWorkerConnCloseMinGrace)),
 	))
 	require.Nil(job)
 }
diff --git a/internal/servers/controller/testing.go b/internal/servers/controller/testing.go
index 3e97bfd3c0..db9b17ed1a 100644
--- a/internal/servers/controller/testing.go
+++ b/internal/servers/controller/testing.go
@@ -10,6 +10,8 @@ import (
 	"testing"
 	"time"
 
+	"github.com/hashicorp/boundary/internal/session"
+
 	grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
 	"github.com/hashicorp/boundary/api"
 	"github.com/hashicorp/boundary/api/authmethods"
@@ -116,6 +118,14 @@ func (tc *TestController) ServersRepo() *servers.Repository {
 	return repo
 }
 
+func (tc *TestController) ConnectionsRepo() *session.ConnectionRepository {
+	repo, err := tc.c.ConnectionRepoFn()
+	if err != nil {
+		tc.t.Fatal(err)
+	}
+	return repo
+}
+
 func (tc *TestController) Cancel() {
 	tc.cancel()
 }
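The new `ConnectionsRepo` helper rounds out the existing per-repository accessors on `TestController`. A sketch of a test helper built on it, using the two-argument `LookupConnection` form that appears in the status tests above; the helper itself is hypothetical:

```go
package controller_test

import (
	"context"
	"testing"

	"github.com/hashicorp/boundary/internal/servers/controller"
	"github.com/hashicorp/boundary/internal/session"
)

// requireConnectionClosed fails the test unless the connection's current
// (most recent) state is closed. A sketch only; not part of this PR.
func requireConnectionClosed(t *testing.T, tc *controller.TestController, connId string) {
	t.Helper()
	connRepo := tc.ConnectionsRepo()
	_, states, err := connRepo.LookupConnection(context.Background(), connId)
	if err != nil {
		t.Fatal(err)
	}
	if len(states) == 0 || states[0].Status != session.StatusClosed {
		t.Fatalf("connection %s is not closed", connId)
	}
}
```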
diff --git a/internal/servers/worker/status.go b/internal/servers/worker/status.go
index 2ae48a3aec..45cb4496c0 100644
--- a/internal/servers/worker/status.go
+++ b/internal/servers/worker/status.go
@@ -153,7 +153,7 @@ func (w *Worker) sendWorkerStatus(cancelCtx context.Context) {
 	// don't have any sessions to worry about anyway.
 	//
 	// If a length of time has passed since we've been able to communicate, we
-	// want to start terminating all sessions as a "break glass" kind of
+	// want to start terminating all connections as a "break glass" kind of
 	// scenario, as there will be no way we can really tell if these
 	// connections should continue to exist.
 
@@ -163,9 +163,22 @@ func (w *Worker) sendWorkerStatus(cancelCtx context.Context) {
 			event.WithInfo("last_status_time", lastStatusTime.String(), "grace_period", gracePeriod),
 		)
 
-		// Run a "cleanup" for all sessions that will not be caught by
-		// our standard cleanup routine.
-		w.cleanupConnections(cancelCtx, true)
+		// Cancel connections if the grace period has expired. These connections will be
+		// closed in the database on the next successful status report, or via the
+		// controller's dead-worker connection cleanup job.
+		w.sessionInfoMap.Range(func(key, value interface{}) bool {
+			si := value.(*session.Info)
+			si.Lock()
+			defer si.Unlock()
+
+			closedIds := w.cancelConnections(si.ConnInfoMap, true)
+			for _, connId := range closedIds {
+				event.WriteSysEvent(cancelCtx, op, "terminated connection due to status grace period expiration", "session_id", si.Id, "connection_id", connId)
+			}
+			return true
+		})
+
+		// Exit out of the status function; our work here is done and we don't need to
+		// create closeConnection requests.
+		return
 	}
 } else {
 	w.updateTags.Store(false)
diff --git a/internal/session/options.go b/internal/session/options.go
index 79fcac5bf6..715d67cd74 100644
--- a/internal/session/options.go
+++ b/internal/session/options.go
@@ -1,6 +1,8 @@
 package session
 
 import (
+	"time"
+
 	"github.com/hashicorp/boundary/internal/db"
 	"github.com/hashicorp/boundary/internal/db/timestamp"
 )
@@ -19,20 +21,25 @@ type Option func(*options)
 
 // options = how options are represented
 type options struct {
-	withLimit             int
-	withOrderByCreateTime db.OrderBy
-	withScopeIds          []string
-	withUserId            string
-	withExpirationTime    *timestamp.Timestamp
-	withTestTofu          []byte
-	withListingConvert    bool
-	withSessionIds        []string
-	withServerId          string
-	withDbOpts            []db.Option
+	withLimit                       int
+	withOrderByCreateTime           db.OrderBy
+	withScopeIds                    []string
+	withUserId                      string
+	withExpirationTime              *timestamp.Timestamp
+	withTestTofu                    []byte
+	withListingConvert              bool
+	withSessionIds                  []string
+	withServerId                    string
+	withDbOpts                      []db.Option
+	withWorkerStateDelay            time.Duration
+	withDeadWorkerConnCloseMinGrace time.Duration
 }
 
 func getDefaultOptions() options {
-	return options{}
+	return options{
+		withWorkerStateDelay:            10 * time.Second,
+		withDeadWorkerConnCloseMinGrace: DeadWorkerConnCloseMinGrace,
+	}
 }
 
 // WithLimit provides an option to provide a limit. Intentionally allowing
@@ -108,3 +115,20 @@ func WithDbOpts(opts ...db.Option) Option {
 		o.withDbOpts = opts
 	}
 }
+
+// WithWorkerStateDelay is used by queries to account for a delay in state
+// propagation between worker and controller.
+func WithWorkerStateDelay(d time.Duration) Option {
+	return func(o *options) {
+		o.withWorkerStateDelay = d
+	}
+}
+
+// WithDeadWorkerConnCloseMinGrace is used to set the minimum allowable setting
+// for the CloseConnectionsForDeadWorkers method. This defaults to the default
+// server liveness setting.
+func WithDeadWorkerConnCloseMinGrace(d time.Duration) Option {
+	return func(o *options) {
+		o.withDeadWorkerConnCloseMinGrace = d
+	}
+}
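Both new functional options are exercised by the tests in this diff: `WithWorkerStateDelay(0)` lets `TestStatusDeadConnection` skip the default 10-second propagation delay, and `WithDeadWorkerConnCloseMinGrace` lowers the floor that `CloseConnectionsForDeadWorkers` will accept. A sketch of passing them to the constructor; the chosen values are illustrative only:

```go
package main

import (
	"context"
	"time"

	"github.com/hashicorp/boundary/internal/db"
	"github.com/hashicorp/boundary/internal/kms"
	"github.com/hashicorp/boundary/internal/session"
)

// newTunedConnectionRepo builds a ConnectionRepository with both new options
// applied, mirroring how the tests in this diff configure it.
func newTunedConnectionRepo(ctx context.Context, rw *db.Db, kmsCache *kms.Kms) (*session.ConnectionRepository, error) {
	return session.NewConnectionRepository(ctx, rw, rw, kmsCache,
		session.WithWorkerStateDelay(0),                       // no worker/controller propagation delay (test-style)
		session.WithDeadWorkerConnCloseMinGrace(time.Second), // permit a 1s dead-worker grace period
	)
}
```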
- ss.end_time is null - ) + ss.session_id in (select * from unexpired_session) and + ss.state = 'active' and + ss.end_time is null +) +insert into session_connection ( + session_id, + public_id, + server_id ) select * from active_session; ` - remainingConnectionsCte = ` with session_connection_count(current_connection_count) as ( - select count(*) - from + select count(*) + from session_connection sc where sc.session_id = @session_id ), session_connection_limit(expiration_time, connection_limit) as ( - select + select s.expiration_time, s.connection_limit from session s - where + where s.public_id = @session_id ) -select expiration_time, connection_limit, current_connection_count -from - session_connection_limit, session_connection_count; +select expiration_time, connection_limit, current_connection_count +from + session_connection_limit, session_connection_count; ` sessionList = ` -select * +select * from (select public_id from session %s) s, session_list ss -where - s.public_id = ss.public_id +where + s.public_id = ss.public_id %s %s +; +` + + terminateSessionIfPossible = ` + -- is terminate_session_id in a canceling state + with session_version as ( + select + version + from + session + where public_id = @public_id + ), + canceling_session(session_id) as + ( + select + session_id + from + session_state ss + where + ss.session_id = @public_id and + ss.state = 'canceling' and + ss.end_time is null + ) + update session us + set version = version +1, + termination_reason = + case + -- timed out sessions + when now() > us.expiration_time then 'timed out' + -- canceling sessions + when us.public_id in( + select + session_id + from + canceling_session cs + where + us.public_id = cs.session_id + ) then 'canceled' + -- default: session connection limit reached. + else 'connection limit' + end + where + -- limit update to just the terminating_session_id + us.public_id = @public_id and + us.version = (select * from session_version) and + termination_reason is null and + -- session expired or connection limit reached + ( + -- expired sessions... + now() > us.expiration_time or + -- connection limit reached... + ( + -- handle unlimited connections... + connection_limit != -1 and + ( + select count (*) + from session_connection sc + where + sc.session_id = us.public_id + ) >= connection_limit + ) or + -- canceled sessions + us.public_id in ( + select + session_id + from + canceling_session cs + where + us.public_id = cs.session_id + ) + ) and + -- make sure there are no existing connections + us.public_id not in ( + select + session_id + from + session_connection + where public_id in ( + select + connection_id + from + session_connection_state + where + state != 'closed' and + end_time is null + ) + ) ` // termSessionUpdate is one stmt that terminates sessions for the following @@ -136,28 +224,28 @@ where termSessionsUpdate = ` with canceling_session(session_id) as ( - select + select session_id from session_state ss - where - ss.state = 'canceling' and + where + ss.state = 'canceling' and ss.end_time is null ) update session us - set termination_reason = - case + set termination_reason = + case -- timed out sessions when now() > us.expiration_time then 'timed out' -- canceling sessions when us.public_id in( - select - session_id - from - canceling_session cs - where + select + session_id + from + canceling_session cs + where us.public_id = cs.session_id - ) then 'canceled' + ) then 'canceled' -- default: session connection limit reached. 
else 'connection limit' end @@ -166,91 +254,46 @@ where -- session expired or connection limit reached ( -- expired sessions... - now() > us.expiration_time or + now() > us.expiration_time or -- connection limit reached... ( -- handle unlimited connections... connection_limit != -1 and ( - select count (*) - from session_connection sc - where + select count (*) + from session_connection sc + where sc.session_id = us.public_id ) >= connection_limit - ) or + ) or -- canceled sessions us.public_id in ( - select + select session_id from canceling_session cs - where - us.public_id = cs.session_id + where + us.public_id = cs.session_id ) - ) and + ) and -- make sure there are no existing connections us.public_id not in ( - select - session_id - from + select + session_id + from session_connection where public_id in ( - select + select connection_id - from + from session_connection_state - where + where state != 'closed' and end_time is null ) -) +); ` - // closeDeadConnectionsCte finds connections that are: - // - // * not closed - // * not announced by a given server in its latest update - // - // and marks them as closed. - closeDeadConnectionsCte = ` -with - -- Find connections that are not closed so we can reference those IDs - unclosed_connections as ( - select connection_id - from session_connection_state - where - -- It's the current state - end_time is null - and - -- Current state isn't closed state - state in ('authorized', 'connected') - and - -- It's not in limbo between when it moved into this state and when - -- it started being reported by the worker, which is roughly every - -- 2-3 seconds - start_time < wt_sub_seconds_from_now(10) - ), - connections_to_close as ( - select public_id - from session_connection - where - -- Related to the worker that just reported to us - server_id = ? - and - -- These are connection IDs that just got reported to us by the given - -- worker, so they should not be considered closed. - %s - -- Only unclosed ones - public_id in (select connection_id from unclosed_connections) - ) - update session_connection - set - closed_reason = 'system error' - where - public_id in (select public_id from connections_to_close) - returning public_id - ` - // closeConnectionsForDeadServersCte finds connections that are: // // * not closed @@ -262,13 +305,12 @@ with // The query returns the set of servers that have had connections closed // along with their last update time and the number of connections closed on // each. - closeConnectionsForDeadServersCte = ` with dead_servers (server_id, last_update_time) as ( select private_id, update_time from server - where update_time < wt_sub_seconds_from_now(?) + where update_time < wt_sub_seconds_from_now(@grace_period_seconds) ), closed_connections (connection_id, server_id) as ( update session_connection @@ -287,35 +329,50 @@ with order by closed_connections.server_id; ` - // shouldCloseConnectionsCte finds connections that are marked as closed in - // the database given a set of connection IDs. They are returned along with - // their associated session ID. - // - // The second parameter is a set of session IDs that we have already - // submitted a session-wide close request for, so sending another change - // request for them would be redundant. 
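Both terminateSessionIfPossible and termSessionsUpdate pick termination_reason with the same precedence: expiry beats cancelation, which beats the connection-limit default. A Go restatement of that CASE expression, purely for readability (this helper does not exist in the codebase):

	// terminationReason mirrors the SQL CASE above: an expired session is
	// "timed out", a canceling one is "canceled", and any other session that
	// qualified for termination must have reached its connection limit.
	func terminationReason(now, expiration time.Time, canceling bool) string {
		switch {
		case now.After(expiration):
			return "timed out"
		case canceling:
			return "canceled"
		default:
			return "connection limit"
		}
	}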
- shouldCloseConnectionsCte = ` + orphanedConnectionsCte = ` +-- Find connections that are not closed so we can reference those IDs with - -- Find connections that are closed so we can reference those IDs - closed_connections as ( + unclosed_connections as ( select connection_id from session_connection_state where -- It's the current state end_time is null - and - -- Current state is closed - state = 'closed' - and - connection_id in (%s) + -- Current state isn't closed state + and state in ('authorized', 'connected') + -- It's not in limbo between when it moved into this state and when + -- it started being reported by the worker, which is roughly every + -- 2-3 seconds + and start_time < wt_sub_seconds_from_now(@worker_state_delay_seconds) + ), + connections_to_close as ( + select public_id + from session_connection + where + -- Related to the worker that just reported to us + server_id = @server_id + -- Only unclosed ones + and public_id in (select connection_id from unclosed_connections) + -- These are connection IDs that just got reported to us by the given + -- worker, so they should not be considered closed. + %s ) - select public_id, session_id - from session_connection - where - public_id in (select connection_id from closed_connections) - -- Below fmt arg is filled in if there are session IDs to filter against - %s - ` +update session_connection + set + closed_reason = 'system error' + where + public_id in (select public_id from connections_to_close) +returning public_id; +` + checkIfNotActive = ` +select session_id, state + from session_state ss +where + (ss.state = 'canceling' or ss.state = 'terminated') + and ss.end_time is null + %s +; +` ) const ( @@ -328,7 +385,7 @@ values (?, ?, ?)` sessionCredentialDynamicBatchInsertReturning = ` - returning session_id, library_id, credential_id, credential_purpose + returning session_id, library_id, credential_id, credential_purpose; ` ) diff --git a/internal/session/repository_connection.go b/internal/session/repository_connection.go index 856b40d636..b42e7ad933 100644 --- a/internal/session/repository_connection.go +++ b/internal/session/repository_connection.go @@ -7,21 +7,150 @@ import ( "strings" "time" + "github.com/hashicorp/boundary/internal/kms" + "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/servers" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) -// deadWorkerConnCloseMinGrace is the minimum allowable setting for +// DeadWorkerConnCloseMinGrace is the minimum allowable setting for // the CloseConnectionsForDeadWorkers method. This is synced with // the default server liveness setting. -var DeadWorkerConnCloseMinGrace = int(servers.DefaultLiveness.Seconds()) +var DeadWorkerConnCloseMinGrace = servers.DefaultLiveness + +// ConnectionRepository is the session connection database repository. +type ConnectionRepository struct { + reader db.Reader + writer db.Writer + kms *kms.Kms + + // defaultLimit provides a default for limiting the number of results returned from the repo + defaultLimit int + + // workerStateDelay is used by queries to account for a delay in state propagation between + // worker and controller + workerStateDelay time.Duration + + // deadWorkerConnCloseMinGrace is the minimum allowable setting for + // the CloseConnectionsForDeadWorkers method. This defaults to + // the default server liveness setting. 
+	deadWorkerConnCloseMinGrace time.Duration
+}
+
+// NewConnectionRepository creates a new session Connection Repository. Supports
+// the option WithLimit, which sets a default limit on results returned by repo
+// operations.
+func NewConnectionRepository(ctx context.Context, r db.Reader, w db.Writer, kms *kms.Kms, opt ...Option) (*ConnectionRepository, error) {
+	const op = "sessionConnection.NewRepository"
+	if r == nil {
+		return nil, errors.New(ctx, errors.InvalidParameter, op, "nil reader")
+	}
+	if w == nil {
+		return nil, errors.New(ctx, errors.InvalidParameter, op, "nil writer")
+	}
+	if kms == nil {
+		return nil, errors.New(ctx, errors.InvalidParameter, op, "nil kms")
+	}
+	opts := getOpts(opt...)
+	if opts.withLimit == 0 {
+		// zero signals the boundary defaults should be used.
+		opts.withLimit = db.DefaultLimit
+	}
+
+	return &ConnectionRepository{
+		reader:                      r,
+		writer:                      w,
+		kms:                         kms,
+		defaultLimit:                opts.withLimit,
+		workerStateDelay:            opts.withWorkerStateDelay,
+		deadWorkerConnCloseMinGrace: opts.withDeadWorkerConnCloseMinGrace,
+	}, nil
+}
+
+// list will return a listing of resources and honor the WithLimit option or the
+// repo defaultLimit. Supports the WithOrder option.
+func (r *ConnectionRepository) list(ctx context.Context, resources interface{}, where string, args []interface{}, opt ...Option) error {
+	const op = "session.(ConnectionRepository).list"
+	opts := getOpts(opt...)
+	limit := r.defaultLimit
+	var dbOpts []db.Option
+	if opts.withLimit != 0 {
+		// non-zero signals an override of the default limit for the repo.
+		limit = opts.withLimit
+	}
+	dbOpts = append(dbOpts, db.WithLimit(limit))
+	switch opts.withOrderByCreateTime {
+	case db.AscendingOrderBy:
+		dbOpts = append(dbOpts, db.WithOrder("create_time asc"))
+	case db.DescendingOrderBy:
+		dbOpts = append(dbOpts, db.WithOrder("create_time"))
+	}
+	if err := r.reader.SearchWhere(ctx, resources, where, args, dbOpts...); err != nil {
+		return errors.Wrap(ctx, err, op)
+	}
+	return nil
+}
+
+// AuthorizeConnection will check to see if a connection is allowed. Currently,
+// that authorization checks:
+// * the session hasn't expired, based on the session.Expiration
+// * the number of connections already created is less than session.ConnectionLimit
+// If authorization succeeds, it creates/stores a new connection in the repo
+// and returns it, along with its states. If the authorization fails, it returns
+// an error with Code InvalidSessionState.
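Because the check and the insert happen in one INSERT ... SELECT statement (see authorizeConnectionCte above), a racing caller cannot slip past the connection limit between validation and creation; zero rows affected is the only failure signal. A sketch of the call-site handling this implies, with connRepo, sessionId, and workerId assumed from the caller's context:

	conn, states, err := connRepo.AuthorizeConnection(ctx, sessionId, workerId)
	if err != nil {
		if errors.Match(errors.T(errors.InvalidSessionState), err) {
			// Not active, expired, or at its connection limit: deny the proxy.
			return status.Error(codes.PermissionDenied, "connection not authorized")
		}
		return err
	}
	// States are ordered by start_time descending, so states[0] is the new
	// "authorized" state for conn.
	_, _ = conn, states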
+func (r *ConnectionRepository) AuthorizeConnection(ctx context.Context, sessionId, workerId string) (*Connection, []*ConnectionState, error) { + const op = "session.(ConnectionRepository).AuthorizeConnection" + if sessionId == "" { + return nil, nil, errors.Wrap(ctx, status.Error(codes.FailedPrecondition, "missing session id"), op, errors.WithCode(errors.InvalidParameter)) + } + connectionId, err := newConnectionId() + if err != nil { + return nil, nil, errors.Wrap(ctx, err, op) + } + + connection := AllocConnection() + connection.PublicId = connectionId + var connectionStates []*ConnectionState + _, err = r.writer.DoTx( + ctx, + db.StdRetryCnt, + db.ExpBackoff{}, + func(reader db.Reader, w db.Writer) error { + rowsAffected, err := w.Exec(ctx, authorizeConnectionCte, []interface{}{ + sql.Named("session_id", sessionId), + sql.Named("public_id", connectionId), + sql.Named("worker_id", workerId), + }) + if err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("unable to authorize connection %s", sessionId))) + } + if rowsAffected == 0 { + return errors.Wrap(ctx, status.Errorf(codes.PermissionDenied, "session %s is not authorized (not active, expired or connection limit reached)", sessionId), op, errors.WithCode(errors.InvalidSessionState)) + } + if err := reader.LookupById(ctx, &connection); err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("failed for session %s", sessionId))) + } + connectionStates, err = fetchConnectionStates(ctx, reader, connectionId, db.WithOrder("start_time desc")) + if err != nil { + return errors.Wrap(ctx, err, op) + } + return nil + }, + ) + if err != nil { + return nil, nil, errors.Wrap(ctx, err, op) + } + + return &connection, connectionStates, nil +} // LookupConnection will look up a connection in the repository and return the connection // with its states. If the connection is not found, it will return nil, nil, nil. // No options are currently supported. -func (r *Repository) LookupConnection(ctx context.Context, connectionId string, _ ...Option) (*Connection, []*ConnectionState, error) { - const op = "session.(Repository).LookupConnection" +func (r *ConnectionRepository) LookupConnection(ctx context.Context, connectionId string, _ ...Option) (*Connection, []*ConnectionState, error) { + const op = "session.(ConnectionRepository).LookupConnection" if connectionId == "" { return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing connectionId id") } @@ -54,8 +183,8 @@ func (r *Repository) LookupConnection(ctx context.Context, connectionId string, // ListConnectionsBySessionId will list connections by session ID. Supports the // WithLimit and WithOrder options. -func (r *Repository) ListConnectionsBySessionId(ctx context.Context, sessionId string, opt ...Option) ([]*Connection, error) { - const op = "session.(Repository).ListConnectionsBySessionId" +func (r *ConnectionRepository) ListConnectionsBySessionId(ctx context.Context, sessionId string, opt ...Option) ([]*Connection, error) { + const op = "session.(ConnectionRepository).ListConnectionsBySessionId" if sessionId == "" { return nil, errors.New(ctx, errors.InvalidParameter, op, "no session ID supplied") } @@ -67,9 +196,130 @@ func (r *Repository) ListConnectionsBySessionId(ctx context.Context, sessionId s return connections, nil } +// ConnectConnection updates a connection in the repo with a state of "connected". 
+func (r *ConnectionRepository) ConnectConnection(ctx context.Context, c ConnectWith) (*Connection, []*ConnectionState, error) {
+	const op = "session.(ConnectionRepository).ConnectConnection"
+	// ConnectWith.validate will check all the fields...
+	if err := c.validate(); err != nil {
+		return nil, nil, errors.Wrap(ctx, err, op)
+	}
+	var connection Connection
+	var connectionStates []*ConnectionState
+	_, err := r.writer.DoTx(
+		ctx,
+		db.StdRetryCnt,
+		db.ExpBackoff{},
+		func(reader db.Reader, w db.Writer) error {
+			connection = AllocConnection()
+			connection.PublicId = c.ConnectionId
+			connection.ClientTcpAddress = c.ClientTcpAddress
+			connection.ClientTcpPort = c.ClientTcpPort
+			connection.EndpointTcpAddress = c.EndpointTcpAddress
+			connection.EndpointTcpPort = c.EndpointTcpPort
+			connection.UserClientIp = c.UserClientIp
+			fieldMask := []string{
+				"ClientTcpAddress",
+				"ClientTcpPort",
+				"EndpointTcpAddress",
+				"EndpointTcpPort",
+				"UserClientIp",
+			}
+			rowsUpdated, err := w.Update(ctx, &connection, fieldMask, nil)
+			if err != nil {
+				return errors.Wrap(ctx, err, op)
+			}
+			if err == nil && rowsUpdated > 1 {
+				// return err, which will result in a rollback of the update
+				return errors.New(ctx, errors.MultipleRecords, op, "more than 1 resource would have been updated")
+			}
+			newState, err := NewConnectionState(connection.PublicId, StatusConnected)
+			if err != nil {
+				return errors.Wrap(ctx, err, op)
+			}
+			if err := w.Create(ctx, newState); err != nil {
+				return errors.Wrap(ctx, err, op)
+			}
+			connectionStates, err = fetchConnectionStates(ctx, reader, c.ConnectionId, db.WithOrder("start_time desc"))
+			if err != nil {
+				return errors.Wrap(ctx, err, op)
+			}
+			return nil
+		},
+	)
+	if err != nil {
+		return nil, nil, errors.Wrap(ctx, err, op)
+	}
+	return &connection, connectionStates, nil
+}
+
+// closeConnectionResp is just a wrapper for the response from closeConnections.
+// It wraps the connection and its states for each connection closed.
+type closeConnectionResp struct {
+	Connection       *Connection
+	ConnectionStates []*ConnectionState
+}
+
+// closeConnections sets a connection's state to "closed" in the repo. It's
+// called by a worker after it has closed a connection between the client and
+// the endpoint.
+func (r *ConnectionRepository) closeConnections(ctx context.Context, closeWith []CloseWith, _ ...Option) ([]closeConnectionResp, error) {
+	const op = "session.(ConnectionRepository).closeConnections"
+	if len(closeWith) == 0 {
+		return nil, errors.New(ctx, errors.InvalidParameter, op, "missing connections")
+	}
+	for _, cw := range closeWith {
+		if err := cw.validate(); err != nil {
+			return nil, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("%s was invalid", cw.ConnectionId)))
+		}
+	}
+	var resp []closeConnectionResp
+	_, err := r.writer.DoTx(
+		ctx,
+		db.StdRetryCnt,
+		db.ExpBackoff{},
+		func(reader db.Reader, w db.Writer) error {
+			for _, cw := range closeWith {
+				updateConnection := AllocConnection()
+				updateConnection.PublicId = cw.ConnectionId
+				updateConnection.BytesUp = cw.BytesUp
+				updateConnection.BytesDown = cw.BytesDown
+				updateConnection.ClosedReason = cw.ClosedReason.String()
+				// updating the ClosedReason will trigger an insert into the
+				// session_connection_state with a state of closed.
+ rowsUpdated, err := w.Update( + ctx, + &updateConnection, + []string{"BytesUp", "BytesDown", "ClosedReason"}, + nil, + ) + if err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("unable to update connection %s", cw.ConnectionId))) + } + if rowsUpdated != 1 { + return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("%d would have been updated for connection %s", rowsUpdated, cw.ConnectionId)) + } + states, err := fetchConnectionStates(ctx, reader, cw.ConnectionId, db.WithOrder("start_time desc")) + if err != nil { + return errors.Wrap(ctx, err, op) + } + resp = append(resp, closeConnectionResp{ + Connection: &updateConnection, + ConnectionStates: states, + }) + + } + return nil + }, + ) + if err != nil { + return nil, errors.Wrap(ctx, err, op) + } + return resp, nil +} + // DeleteConnection will delete a connection from the repository. -func (r *Repository) DeleteConnection(ctx context.Context, publicId string, _ ...Option) (int, error) { - const op = "session.(Repository).DeleteConnection" +func (r *ConnectionRepository) DeleteConnection(ctx context.Context, publicId string, _ ...Option) (int, error) { + const op = "session.(ConnectionRepository).DeleteConnection" if publicId == "" { return db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, "missing public id") } @@ -107,53 +357,6 @@ func (r *Repository) DeleteConnection(ctx context.Context, publicId string, _ .. return rowsDeleted, nil } -// CloseDeadConnectionsForWorker will run closeDeadConnectionsCte to look for -// connections that should be marked closed because they are no longer claimed -// by a server. -// -// The foundConns input should be the currently-claimed connections; the CTE -// uses a NOT IN clause to ensure these are excluded. It is not an error for -// this to be empty as the worker could claim no connections; in that case all -// connections will immediately transition to closed. -func (r *Repository) CloseDeadConnectionsForWorker(ctx context.Context, serverId string, foundConns []string) (int, error) { - const op = "session.(Repository).CloseDeadConnectionsForWorker" - if serverId == "" { - return db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, "missing server id") - } - - args := make([]interface{}, 0, len(foundConns)+1) - args = append(args, serverId) - - var publicIdStr string - if len(foundConns) > 0 { - publicIdStr = `public_id not in (%s) and` - params := make([]string, len(foundConns)) - for i, connId := range foundConns { - params[i] = fmt.Sprintf("@%d", i+2) // Add one for server ID, and offsets start at 1 - args = append(args, sql.Named(fmt.Sprintf("%d", i+2), connId)) - } - publicIdStr = fmt.Sprintf(publicIdStr, strings.Join(params, ",")) - } - var rowsAffected int - _, err := r.writer.DoTx( - ctx, - db.StdRetryCnt, - db.ExpBackoff{}, - func(reader db.Reader, w db.Writer) error { - var err error - rowsAffected, err = w.Exec(ctx, fmt.Sprintf(closeDeadConnectionsCte, publicIdStr), args) - if err != nil { - return errors.Wrap(ctx, err, op) - } - return nil - }, - ) - if err != nil { - return db.NoRowsAffected, errors.Wrap(ctx, err, op) - } - return rowsAffected, nil -} - type CloseConnectionsForDeadWorkersResult struct { ServerId string LastUpdateTime time.Time @@ -166,20 +369,23 @@ type CloseConnectionsForDeadWorkersResult struct { // sending status updates to the controller(s). // // The only input to the method is the grace period, in seconds. 
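CloseConnectionsForDeadWorkers is the controller-side safety net: any worker whose server row hasn't been updated within the grace period has all of its non-closed connections marked closed in one statement. A sketch of how a recurring controller job might consume it; the logging and the NumberConnectionsClosed field name are assumptions, since the result struct's remaining fields are elided above:

	results, err := connRepo.CloseConnectionsForDeadWorkers(ctx, servers.DefaultLiveness)
	if err != nil {
		return err
	}
	for _, res := range results {
		// One row per dead worker that actually had connections closed.
		log.Printf("closed %d connections for dead worker %s (last update %s)",
			res.NumberConnectionsClosed, res.ServerId, res.LastUpdateTime)
	}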
-func (r *Repository) CloseConnectionsForDeadWorkers(ctx context.Context, gracePeriod int) ([]CloseConnectionsForDeadWorkersResult, error) {
-	const op = "session.(Repository).CloseConnectionsForDeadWorkers"
-	if gracePeriod < DeadWorkerConnCloseMinGrace {
+func (r *ConnectionRepository) CloseConnectionsForDeadWorkers(ctx context.Context, gracePeriod time.Duration) ([]CloseConnectionsForDeadWorkersResult, error) {
+	const op = "session.(ConnectionRepository).CloseConnectionsForDeadWorkers"
+	if gracePeriod < r.deadWorkerConnCloseMinGrace {
 		return nil, errors.New(ctx,
-			errors.InvalidParameter, op, fmt.Sprintf("gracePeriod must be at least %d seconds", DeadWorkerConnCloseMinGrace))
+			errors.InvalidParameter, op, fmt.Sprintf("gracePeriod must be at least %s", r.deadWorkerConnCloseMinGrace))
 	}
+	args := []interface{}{
+		sql.Named("grace_period_seconds", gracePeriod.Seconds()),
+	}
 	results := make([]CloseConnectionsForDeadWorkersResult, 0)
 	_, err := r.writer.DoTx(
 		ctx,
 		db.StdRetryCnt,
 		db.ExpBackoff{},
 		func(reader db.Reader, w db.Writer) error {
-			rows, err := w.Query(ctx, closeConnectionsForDeadServersCte, []interface{}{gracePeriod})
+			rows, err := w.Query(ctx, closeConnectionsForDeadServersCte, args)
 			if err != nil {
 				return errors.Wrap(ctx, err, op)
 			}
@@ -204,67 +410,52 @@ func (r *Repository) CloseConnectionsForDeadWorkers(ctx context.Context, gracePe
 	return results, nil
 }
 
-// ShouldCloseConnectionsOnWorker will run shouldCloseConnectionsCte to look
-// for connections that the worker should close because they are currently
-// reporting them as open incorrectly.
-//
-// The foundConns input here is used to filter closed connection states. This
-// is further filtered against the filterSessions input, which is expected to
-// be a set of sessions we've already submitted close requests for, so adding
-// them again would be redundant.
-//
-// The returned map[string][]string is indexed by session ID.
-func (r *Repository) ShouldCloseConnectionsOnWorker(ctx context.Context, foundConns, filterSessions []string) (map[string][]string, error) {
-	const op = "session.(Repository).ShouldCloseConnectionsOnWorker"
-	if len(foundConns) < 1 {
-		return nil, nil // nothing to do
-	}
+// closeOrphanedConnections looks for connections that are still active but were not reported by the worker.
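The intended flow for the method below: the worker's status report carries the connection IDs it still claims, and everything else attributed to that worker (and older than the worker-state delay) is closed as orphaned. The method is unexported, so in practice it is reached through package-level plumbing; this direct call is only a sketch of the shape, with workerId and reportedConnIds assumed from the status request:

	orphaned, err := connRepo.closeOrphanedConnections(ctx, workerId, reportedConnIds)
	if err != nil {
		return err
	}
	for _, connId := range orphaned {
		// The worker learns about these on its next status response so it
		// can tear down its side of each connection.
		event.WriteSysEvent(ctx, op, "closed orphaned connection", "connection_id", connId)
	}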
+func (r *ConnectionRepository) closeOrphanedConnections(ctx context.Context, workerId string, reportedConnections []string) ([]string, error) { + const op = "session.(ConnectionRepository).closeOrphanedConnections" - args := make([]interface{}, 0, len(foundConns)+len(filterSessions)) + var orphanedConns []string - // foundConns first - connsParams := make([]string, len(foundConns)) - for i, connId := range foundConns { - connsParams[i] = fmt.Sprintf("@%d", i+1) - args = append(args, sql.Named(fmt.Sprintf("%d", i+1), connId)) - } - connsStr := strings.Join(connsParams, ",") - - // then filterSessions - var sessionsStr string - if len(filterSessions) > 0 { - offset := len(foundConns) + 1 - sessionsParams := make([]string, len(filterSessions)) - for i, sessionId := range filterSessions { - sessionsParams[i] = fmt.Sprintf("@%d", i+offset) - args = append(args, sql.Named(fmt.Sprintf("%d", i+offset), sessionId)) - } + args := make([]interface{}, 0, len(reportedConnections)+2) + args = append(args, sql.Named("server_id", workerId)) + args = append(args, sql.Named("worker_state_delay_seconds", r.workerStateDelay.Seconds())) - const sessionIdFmtStr = `and session_id not in (%s)` - sessionsStr = fmt.Sprintf(sessionIdFmtStr, strings.Join(sessionsParams, ",")) + var notInClause string + if len(reportedConnections) > 0 { + notInClause = `and public_id not in (%s)` + params := make([]string, len(reportedConnections)) + for i, connId := range reportedConnections { + params[i] = fmt.Sprintf("@%d", i) + args = append(args, sql.Named(fmt.Sprintf("%d", i), connId)) + } + notInClause = fmt.Sprintf(notInClause, strings.Join(params, ",")) } - rows, err := r.reader.Query( + _, err := r.writer.DoTx( ctx, - fmt.Sprintf(shouldCloseConnectionsCte, connsStr, sessionsStr), - args, + db.StdRetryCnt, + db.ExpBackoff{}, + func(reader db.Reader, w db.Writer) error { + rows, err := r.reader.Query(ctx, fmt.Sprintf(orphanedConnectionsCte, notInClause), args) + if err != nil { + return errors.Wrap(ctx, err, op) + } + defer rows.Close() + + for rows.Next() { + var connectionId string + if err := rows.Scan(&connectionId); err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("scan row failed")) + } + orphanedConns = append(orphanedConns, connectionId) + } + return nil + }, ) if err != nil { - return nil, errors.Wrap(ctx, err, op) + return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error comparing state")) } - defer rows.Close() - - result := make(map[string][]string) - for rows.Next() { - var connectionId, sessionId string - if err := rows.Scan(&connectionId, &sessionId); err != nil { - return nil, errors.Wrap(ctx, err, op) - } - - result[sessionId] = append(result[sessionId], connectionId) - } - - return result, nil + return orphanedConns, nil } func fetchConnectionStates(ctx context.Context, r db.Reader, connectionId string, opt ...db.Option) ([]*ConnectionState, error) { diff --git a/internal/session/repository_connection_test.go b/internal/session/repository_connection_test.go index 7e5753d353..764c9a831d 100644 --- a/internal/session/repository_connection_test.go +++ b/internal/session/repository_connection_test.go @@ -12,13 +12,13 @@ import ( "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/oplog" "github.com/hashicorp/boundary/internal/servers" - "github.com/hashicorp/go-secure-stdlib/strutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestRepository_ListConnection(t *testing.T) { t.Parallel() + ctx := context.Background() conn, _ := 
db.TestSetup(t, "postgres") const testLimit = 10 wrapper := db.TestWrapper(t) @@ -26,6 +26,7 @@ func TestRepository_ListConnection(t *testing.T) { rw := db.New(conn) kms := kms.TestKms(t, conn, wrapper) repo, err := NewRepository(rw, rw, kms, WithLimit(testLimit)) + connRepo, err := NewConnectionRepository(ctx, rw, rw, kms, WithLimit(testLimit)) require.NoError(t, err) session := TestDefaultSession(t, conn, wrapper, iamRepo) @@ -96,7 +97,7 @@ func TestRepository_ListConnection(t *testing.T) { testConnections = append(testConnections, c) } assert.Equal(tt.createCnt, len(testConnections)) - got, err := repo.ListConnectionsBySessionId(context.Background(), tt.args.searchForSessionId, tt.args.opt...) + got, err := connRepo.ListConnectionsBySessionId(context.Background(), tt.args.searchForSessionId, tt.args.opt...) if tt.wantErr { require.Error(err) return @@ -119,7 +120,7 @@ func TestRepository_ListConnection(t *testing.T) { "127.0.0.1", ) } - got, err := repo.ListConnectionsBySessionId(context.Background(), session.PublicId, WithOrderByCreateTime(db.AscendingOrderBy)) + got, err := connRepo.ListConnectionsBySessionId(context.Background(), session.PublicId, WithOrderByCreateTime(db.AscendingOrderBy)) require.NoError(err) assert.Equal(wantCnt, len(got)) @@ -131,14 +132,140 @@ func TestRepository_ListConnection(t *testing.T) { }) } -func TestRepository_DeleteConnection(t *testing.T) { +func TestRepository_ConnectConnection(t *testing.T) { t.Parallel() + ctx := context.Background() conn, _ := db.TestSetup(t, "postgres") rw := db.New(conn) wrapper := db.TestWrapper(t) iamRepo := iam.TestRepo(t, conn, wrapper) kms := kms.TestKms(t, conn, wrapper) repo, err := NewRepository(rw, rw, kms) + connRepo, err := NewConnectionRepository(ctx, rw, rw, kms) + require.NoError(t, err) + + setupFn := func() ConnectWith { + s := TestDefaultSession(t, conn, wrapper, iamRepo) + srv := TestWorker(t, conn, wrapper) + tofu := TestTofu(t) + _, _, err := repo.ActivateSession(context.Background(), s.PublicId, s.Version, srv.PrivateId, srv.Type, tofu) + require.NoError(t, err) + c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222, "127.0.0.1") + return ConnectWith{ + ConnectionId: c.PublicId, + ClientTcpAddress: "127.0.0.1", + ClientTcpPort: 22, + EndpointTcpAddress: "127.0.0.1", + EndpointTcpPort: 2222, + UserClientIp: "127.0.0.1", + } + } + tests := []struct { + name string + connectWith ConnectWith + wantErr bool + wantIsError errors.Code + }{ + { + name: "valid", + connectWith: setupFn(), + }, + { + name: "empty-SessionId", + connectWith: func() ConnectWith { + cw := setupFn() + cw.ConnectionId = "" + return cw + }(), + wantErr: true, + wantIsError: errors.InvalidParameter, + }, + { + name: "empty-ClientTcpAddress", + connectWith: func() ConnectWith { + cw := setupFn() + cw.ClientTcpAddress = "" + return cw + }(), + wantErr: true, + wantIsError: errors.InvalidParameter, + }, + { + name: "empty-ClientTcpPort", + connectWith: func() ConnectWith { + cw := setupFn() + cw.ClientTcpPort = 0 + return cw + }(), + wantErr: true, + wantIsError: errors.InvalidParameter, + }, + { + name: "empty-EndpointTcpAddress", + connectWith: func() ConnectWith { + cw := setupFn() + cw.EndpointTcpAddress = "" + return cw + }(), + wantErr: true, + wantIsError: errors.InvalidParameter, + }, + { + name: "empty-EndpointTcpPort", + connectWith: func() ConnectWith { + cw := setupFn() + cw.EndpointTcpPort = 0 + return cw + }(), + wantErr: true, + wantIsError: errors.InvalidParameter, + }, + { + name: "empty-UserClientIp", + 
connectWith: func() ConnectWith { + cw := setupFn() + cw.UserClientIp = "" + return cw + }(), + wantErr: true, + wantIsError: errors.InvalidParameter, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + c, cs, err := connRepo.ConnectConnection(context.Background(), tt.connectWith) + if tt.wantErr { + require.Error(err) + assert.Truef(errors.Match(errors.T(tt.wantIsError), err), "unexpected error %s", err.Error()) + return + } + require.NoError(err) + require.NotNil(c) + require.NotNil(cs) + assert.Equal(StatusConnected, cs[0].Status) + gotConn, _, err := connRepo.LookupConnection(context.Background(), c.PublicId) + require.NoError(err) + assert.Equal(tt.connectWith.ClientTcpAddress, gotConn.ClientTcpAddress) + assert.Equal(tt.connectWith.ClientTcpPort, gotConn.ClientTcpPort) + assert.Equal(tt.connectWith.ClientTcpAddress, gotConn.ClientTcpAddress) + assert.Equal(tt.connectWith.EndpointTcpAddress, gotConn.EndpointTcpAddress) + assert.Equal(tt.connectWith.EndpointTcpPort, gotConn.EndpointTcpPort) + assert.Equal(tt.connectWith.UserClientIp, gotConn.UserClientIp) + }) + } +} + +func TestRepository_DeleteConnection(t *testing.T) { + t.Parallel() + ctx := context.Background() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + iamRepo := iam.TestRepo(t, conn, wrapper) + kms := kms.TestKms(t, conn, wrapper) + connRepo, err := NewConnectionRepository(ctx, rw, rw, kms) require.NoError(t, err) session := TestDefaultSession(t, conn, wrapper, iamRepo) @@ -171,7 +298,7 @@ func TestRepository_DeleteConnection(t *testing.T) { }, wantRowsDeleted: 0, wantErr: true, - wantErrMsg: "session.(Repository).DeleteConnection: missing public id: parameter violation: error #100", + wantErrMsg: "session.(ConnectionRepository).DeleteConnection: missing public id: parameter violation: error #100", }, { name: "not-found", @@ -192,7 +319,7 @@ func TestRepository_DeleteConnection(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { assert := assert.New(t) - deletedRows, err := repo.DeleteConnection(context.Background(), tt.args.connection.PublicId, tt.args.opt...) + deletedRows, err := connRepo.DeleteConnection(context.Background(), tt.args.connection.PublicId, tt.args.opt...) if tt.wantErr { assert.Error(err) assert.Equal(0, deletedRows) @@ -204,7 +331,7 @@ func TestRepository_DeleteConnection(t *testing.T) { } assert.NoError(err) assert.Equal(tt.wantRowsDeleted, deletedRows) - found, _, err := repo.LookupConnection(context.Background(), tt.args.connection.PublicId) + found, _, err := connRepo.LookupConnection(context.Background(), tt.args.connection.PublicId) assert.NoError(err) assert.Nil(found) @@ -214,8 +341,9 @@ func TestRepository_DeleteConnection(t *testing.T) { } } -func TestRepository_CloseDeadConnectionsOnWorker(t *testing.T) { +func TestRepository_orphanedConnections(t *testing.T) { t.Parallel() + ctx := context.Background() require, assert := require.New(t), assert.New(t) conn, _ := db.TestSetup(t, "postgres") rw := db.New(conn) @@ -223,8 +351,8 @@ func TestRepository_CloseDeadConnectionsOnWorker(t *testing.T) { iamRepo := iam.TestRepo(t, conn, wrapper) kms := kms.TestKms(t, conn, wrapper) repo, err := NewRepository(rw, rw, kms) + connRepo, err := NewConnectionRepository(ctx, rw, rw, kms, WithWorkerStateDelay(0)) require.NoError(err) - ctx := context.Background() numConns := 12 // Create two "workers". 
One will remain untouched while the other "goes @@ -234,6 +362,7 @@ func TestRepository_CloseDeadConnectionsOnWorker(t *testing.T) { // Create a few sessions on each, activate, and authorize a connection var connIds []string + var worker1ConnIds []string var worker2ConnIds []string for i := 0; i < numConns; i++ { serverId := worker1.PrivateId @@ -243,13 +372,15 @@ func TestRepository_CloseDeadConnectionsOnWorker(t *testing.T) { sess := TestDefaultSession(t, conn, wrapper, iamRepo, WithServerId(serverId), WithDbOpts(db.WithSkipVetForWrite(true))) sess, _, err = repo.ActivateSession(ctx, sess.GetPublicId(), sess.Version, serverId, "worker", []byte("foo")) require.NoError(err) - c, cs, _, err := repo.AuthorizeConnection(ctx, sess.GetPublicId(), serverId) + c, cs, err := connRepo.AuthorizeConnection(ctx, sess.GetPublicId(), serverId) require.NoError(err) require.Len(cs, 1) require.Equal(StatusAuthorized, cs[0].Status) connIds = append(connIds, c.GetPublicId()) if i%2 == 0 { worker2ConnIds = append(worker2ConnIds, c.GetPublicId()) + } else { + worker1ConnIds = append(worker1ConnIds, c.GetPublicId()) } } @@ -257,7 +388,7 @@ func TestRepository_CloseDeadConnectionsOnWorker(t *testing.T) { // This is just to ensure we have a spread when we test it out. for i, connId := range connIds { if i%2 == 0 { - _, cs, err := repo.ConnectConnection(ctx, ConnectWith{ + _, cs, err := connRepo.ConnectConnection(ctx, ConnectWith{ ConnectionId: connId, ClientTcpAddress: "127.0.0.1", ClientTcpPort: 22, @@ -281,63 +412,33 @@ func TestRepository_CloseDeadConnectionsOnWorker(t *testing.T) { } } - // There is a 10 second delay to account for time for the connections to - // transition - time.Sleep(15 * time.Second) - // Now, advertise only some of the connection IDs for worker 2. After, // all connection IDs for worker 1 should be showing as non-closed, and // the ones for worker 2 not advertised should be closed. shouldStayOpen := worker2ConnIds[0:2] - count, err := repo.CloseDeadConnectionsForWorker(ctx, worker2.GetPrivateId(), shouldStayOpen) + found, err := connRepo.closeOrphanedConnections(ctx, worker2.GetPrivateId(), shouldStayOpen) require.NoError(err) - assert.Equal(4, count) + fmt.Printf("shouldstate: %v\nfound: %v\n", shouldStayOpen, found) + require.Equal(4, len(found)) // For the ones we didn't specify, we expect those to now be closed. We // expect all others to be open. - shouldBeClosed := worker2ConnIds[2:] - var conns []*Connection - require.NoError(repo.list(ctx, &conns, "", nil)) - for _, conn := range conns { - _, states, err := repo.LookupConnection(ctx, conn.PublicId) - require.NoError(err) - var foundClosed bool - for _, state := range states { - if state.Status == StatusClosed { - foundClosed = true - break - } - } - assert.True(foundClosed == strutil.StrListContains(shouldBeClosed, conn.PublicId)) - } + shouldBeFound := worker2ConnIds[2:] + assert.ElementsMatch(found, shouldBeFound) // Now, advertise none of the connection IDs for worker 2. This is mainly to // test that handling the case where we do not include IDs works properly as // it changes the where clause. 
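The "advertise none" branch matters because the repository only splices the not-in clause into the query when there is something to exclude; with an empty report the where clause collapses and every unclosed connection for the worker is fair game. The parameter-building pattern from closeOrphanedConnections, reduced to a standalone sketch (assuming the usual database/sql, fmt, and strings imports):

	// buildNotInClause renders `and public_id not in (@0, @1, ...)` plus the
	// matching sql.Named args, or nothing when there are no IDs to exclude.
	func buildNotInClause(ids []string) (clause string, args []interface{}) {
		if len(ids) == 0 {
			return "", nil
		}
		params := make([]string, len(ids))
		for i, id := range ids {
			params[i] = fmt.Sprintf("@%d", i)
			args = append(args, sql.Named(fmt.Sprintf("%d", i), id))
		}
		return fmt.Sprintf("and public_id not in (%s)", strings.Join(params, ",")), args
	}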
- count, err = repo.CloseDeadConnectionsForWorker(ctx, worker1.GetPrivateId(), nil) + found, err = connRepo.closeOrphanedConnections(ctx, worker1.GetPrivateId(), nil) require.NoError(err) - assert.Equal(6, count) - - // We now expect all but those blessed few to be closed - conns = nil - require.NoError(repo.list(ctx, &conns, "", nil)) - for _, conn := range conns { - _, states, err := repo.LookupConnection(ctx, conn.PublicId) - require.NoError(err) - var foundClosed bool - for _, state := range states { - if state.Status == StatusClosed { - foundClosed = true - break - } - } - assert.True(foundClosed != strutil.StrListContains(shouldStayOpen, conn.PublicId)) - } + assert.Equal(6, len(found)) + assert.ElementsMatch(found, worker1ConnIds) } func TestRepository_CloseConnectionsForDeadWorkers(t *testing.T) { t.Parallel() + ctx := context.Background() require := require.New(t) conn, _ := db.TestSetup(t, "postgres") rw := db.New(conn) @@ -345,10 +446,11 @@ func TestRepository_CloseConnectionsForDeadWorkers(t *testing.T) { iamRepo := iam.TestRepo(t, conn, wrapper) kms := kms.TestKms(t, conn, wrapper) repo, err := NewRepository(rw, rw, kms) + deadWorkerConnCloseMinGrace := 1 * time.Second + connRepo, err := NewConnectionRepository(ctx, rw, rw, kms, WithDeadWorkerConnCloseMinGrace(deadWorkerConnCloseMinGrace)) require.NoError(err) serversRepo, err := servers.NewRepository(rw, rw, kms) require.NoError(err) - ctx := context.Background() // connection count = 6 * states(authorized, connected, closed = 3) * servers_with_open_connections(3) numConns := 54 @@ -377,7 +479,7 @@ func TestRepository_CloseConnectionsForDeadWorkers(t *testing.T) { sess := TestDefaultSession(t, conn, wrapper, iamRepo, WithServerId(serverId), WithDbOpts(db.WithSkipVetForWrite(true))) sess, _, err = repo.ActivateSession(ctx, sess.GetPublicId(), sess.Version, serverId, "worker", []byte("foo")) require.NoError(err) - c, cs, _, err := repo.AuthorizeConnection(ctx, sess.GetPublicId(), serverId) + c, cs, err := connRepo.AuthorizeConnection(ctx, sess.GetPublicId(), serverId) require.NoError(err) require.Len(cs, 1) require.Equal(StatusAuthorized, cs[0].Status) @@ -401,7 +503,7 @@ func TestRepository_CloseConnectionsForDeadWorkers(t *testing.T) { return s }() { if i%3 == 0 { - _, cs, err := repo.ConnectConnection(ctx, ConnectWith{ + _, cs, err := connRepo.ConnectConnection(ctx, ConnectWith{ ConnectionId: connId, ClientTcpAddress: "127.0.0.1", ClientTcpPort: 22, @@ -423,7 +525,7 @@ func TestRepository_CloseConnectionsForDeadWorkers(t *testing.T) { require.True(foundAuthorized) require.True(foundConnected) } else if i%3 == 1 { - resp, err := repo.CloseConnections(ctx, []CloseWith{ + resp, err := connRepo.closeConnections(ctx, []CloseWith{ { ConnectionId: connId, ClosedReason: ConnectionCanceled, @@ -447,10 +549,6 @@ func TestRepository_CloseConnectionsForDeadWorkers(t *testing.T) { } } - // There is a 15 second delay to account for time for the connections to - // transition - time.Sleep(15 * time.Second) - // updateServer is a helper for updating the update time for our // servers. The controller is read back so that we can reference // the most up-to-date fields. 
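The updateServer helper keeps selected workers "alive" by refreshing server.update_time; a worker counts as dead exactly when that timestamp falls behind now() minus the grace period, which is what wt_sub_seconds_from_now(@grace_period_seconds) encodes in the CTE. The same predicate restated in Go, for orientation only (not code from this patch):

	// isDeadWorker mirrors `update_time < wt_sub_seconds_from_now(grace)`.
	func isDeadWorker(lastUpdate time.Time, grace time.Duration) bool {
		return lastUpdate.Before(time.Now().Add(-grace))
	}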
@@ -496,7 +594,7 @@ func TestRepository_CloseConnectionsForDeadWorkers(t *testing.T) { expected = StatusAuthorized } - _, states, err := repo.LookupConnection(ctx, connId) + _, states, err := connRepo.LookupConnection(ctx, connId) require.NoError(err) require.Equal(expected, states[0].Status, "expected latest status for %q (index %d) to be %v", connId, i, expected) } @@ -515,11 +613,11 @@ func TestRepository_CloseConnectionsForDeadWorkers(t *testing.T) { // Now try some scenarios. { // First, test the error/validation case. - result, err := repo.CloseConnectionsForDeadWorkers(ctx, 0) + result, err := connRepo.CloseConnectionsForDeadWorkers(ctx, -1) require.Equal(err, errors.E(ctx, errors.WithCode(errors.InvalidParameter), - errors.WithOp("session.(Repository).CloseConnectionsForDeadWorkers"), - errors.WithMsg(fmt.Sprintf("gracePeriod must be at least %d seconds", DeadWorkerConnCloseMinGrace)), + errors.WithOp("session.(ConnectionRepository).CloseConnectionsForDeadWorkers"), + errors.WithMsg(fmt.Sprintf("gracePeriod must be at least %s", deadWorkerConnCloseMinGrace)), )) require.Nil(result) } @@ -531,7 +629,7 @@ func TestRepository_CloseConnectionsForDeadWorkers(t *testing.T) { worker3 = updateServer(t, worker3) updateServer(t, worker4) // no re-assignment here because we never reference the server again - result, err := repo.CloseConnectionsForDeadWorkers(ctx, DeadWorkerConnCloseMinGrace) + result, err := connRepo.CloseConnectionsForDeadWorkers(ctx, deadWorkerConnCloseMinGrace) require.NoError(err) require.Empty(result) // Expect appropriate split connection state on worker1 @@ -546,12 +644,12 @@ func TestRepository_CloseConnectionsForDeadWorkers(t *testing.T) { // Now try a zero case - similar to the basis, but only in that no results // are expected to be returned for workers with no connections, even if // they are dead. Here, the server with no connections is worker #4. - time.Sleep(time.Second * time.Duration(DeadWorkerConnCloseMinGrace)) + time.Sleep(deadWorkerConnCloseMinGrace) worker1 = updateServer(t, worker1) worker2 = updateServer(t, worker2) worker3 = updateServer(t, worker3) - result, err := repo.CloseConnectionsForDeadWorkers(ctx, DeadWorkerConnCloseMinGrace) + result, err := connRepo.CloseConnectionsForDeadWorkers(ctx, deadWorkerConnCloseMinGrace) require.NoError(err) require.Empty(result) // Expect appropriate split connection state on worker1 @@ -565,11 +663,11 @@ func TestRepository_CloseConnectionsForDeadWorkers(t *testing.T) { { // The first induction is letting the first worker "die" by not updating it // too. All of its authorized and connected connections should be dead. - time.Sleep(time.Second * time.Duration(DeadWorkerConnCloseMinGrace)) + time.Sleep(deadWorkerConnCloseMinGrace) worker2 = updateServer(t, worker2) worker3 = updateServer(t, worker3) - result, err := repo.CloseConnectionsForDeadWorkers(ctx, DeadWorkerConnCloseMinGrace) + result, err := connRepo.CloseConnectionsForDeadWorkers(ctx, deadWorkerConnCloseMinGrace) require.NoError(err) // Assert that we have one result with the appropriate ID and // number of connections closed. Due to how things are @@ -592,9 +690,9 @@ func TestRepository_CloseConnectionsForDeadWorkers(t *testing.T) { // The final case is having the other two workers die. After // this, we should have all connections closed with the // appropriate message from the next two servers acted on. 
- time.Sleep(time.Second * time.Duration(DeadWorkerConnCloseMinGrace)) + time.Sleep(deadWorkerConnCloseMinGrace) - result, err := repo.CloseConnectionsForDeadWorkers(ctx, DeadWorkerConnCloseMinGrace) + result, err := connRepo.CloseConnectionsForDeadWorkers(ctx, deadWorkerConnCloseMinGrace) require.NoError(err) // Assert that we have one result with the appropriate ID and number of connections closed. require.Equal([]CloseConnectionsForDeadWorkersResult{ @@ -618,159 +716,84 @@ func TestRepository_CloseConnectionsForDeadWorkers(t *testing.T) { } } -func TestRepository_ShouldCloseConnectionsOnWorker(t *testing.T) { +func TestRepository_CloseConnections(t *testing.T) { t.Parallel() - require := require.New(t) + ctx := context.Background() conn, _ := db.TestSetup(t, "postgres") rw := db.New(conn) wrapper := db.TestWrapper(t) iamRepo := iam.TestRepo(t, conn, wrapper) kms := kms.TestKms(t, conn, wrapper) repo, err := NewRepository(rw, rw, kms) - require.NoError(err) - ctx := context.Background() - numConns := 12 - - // Create a worker, we only need one here as our query is dependent - // on connection and not worker. - worker1 := TestWorker(t, conn, wrapper, WithServerId("worker1")) - - // Create a few sessions on each, activate, and authorize a connection - var connIds []string - sessionConnIds := make(map[string][]string) - for i := 0; i < numConns; i++ { - serverId := worker1.PrivateId - sess := TestDefaultSession(t, conn, wrapper, iamRepo, WithServerId(serverId), WithDbOpts(db.WithSkipVetForWrite(true))) - sessionId := sess.GetPublicId() - sess, _, err = repo.ActivateSession(ctx, sessionId, sess.Version, serverId, "worker", []byte("foo")) - require.NoError(err) - c, cs, _, err := repo.AuthorizeConnection(ctx, sess.GetPublicId(), serverId) - require.NoError(err) - require.Len(cs, 1) - require.Equal(StatusAuthorized, cs[0].Status) - connId := c.GetPublicId() - connIds = append(connIds, connId) - sessionConnIds[sessionId] = append(sessionConnIds[sessionId], connId) - } + connRepo, err := NewConnectionRepository(ctx, rw, rw, kms) + require.NoError(t, err) - // Mark half of the connections connected, close the other half. 
- for i, connId := range connIds { - if i%2 == 0 { - _, cs, err := repo.ConnectConnection(ctx, ConnectWith{ - ConnectionId: connId, - ClientTcpAddress: "127.0.0.1", - ClientTcpPort: 22, - EndpointTcpAddress: "127.0.0.1", - EndpointTcpPort: 22, - UserClientIp: "127.0.0.1", - }) - require.NoError(err) - require.Len(cs, 2) - var foundAuthorized, foundConnected bool - for _, status := range cs { - if status.Status == StatusAuthorized { - foundAuthorized = true - } - if status.Status == StatusConnected { - foundConnected = true - } - } - require.True(foundAuthorized) - require.True(foundConnected) - } else { - resp, err := repo.CloseConnections(ctx, []CloseWith{ - { - ConnectionId: connId, - ClosedReason: ConnectionCanceled, - }, + setupFn := func(cnt int) []CloseWith { + s := TestDefaultSession(t, conn, wrapper, iamRepo) + srv := TestWorker(t, conn, wrapper) + tofu := TestTofu(t) + s, _, err := repo.ActivateSession(context.Background(), s.PublicId, s.Version, srv.PrivateId, srv.Type, tofu) + require.NoError(t, err) + cw := make([]CloseWith, 0, cnt) + for i := 0; i < cnt; i++ { + c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222, "127.0.0.1") + require.NoError(t, err) + cw = append(cw, CloseWith{ + ConnectionId: c.PublicId, + BytesUp: 1, + BytesDown: 2, + ClosedReason: ConnectionClosedByUser, }) - require.NoError(err) - require.Len(resp, 1) - cs := resp[0].ConnectionStates - require.Len(cs, 2) - var foundAuthorized, foundClosed bool - for _, status := range cs { - if status.Status == StatusAuthorized { - foundAuthorized = true - } - if status.Status == StatusClosed { - foundClosed = true - } - } - require.True(foundAuthorized) - require.True(foundClosed) } + return cw } - - // There is a 10 second delay to account for time for the connections to - // transition - time.Sleep(15 * time.Second) - - // Now we try some scenarios. - { - // First test an empty set. - result, err := repo.ShouldCloseConnectionsOnWorker(ctx, nil, nil) - require.NoError(err) - require.Zero(result, "should be empty when no connections are supplied") + tests := []struct { + name string + closeWith []CloseWith + reason TerminationReason + wantErr bool + wantIsError errors.Code + }{ + { + name: "valid", + closeWith: setupFn(2), + reason: ClosedByUser, + }, + { + name: "empty-closed-with", + closeWith: []CloseWith{}, + reason: ClosedByUser, + wantErr: true, + wantIsError: errors.InvalidParameter, + }, + { + name: "missing-ConnectionId", + closeWith: func() []CloseWith { + cw := setupFn(2) + cw[1].ConnectionId = "" + return cw + }(), + reason: ClosedByUser, + wantErr: true, + wantIsError: errors.InvalidParameter, + }, } - - { - // Here we pass in all of our connections without a filter on - // session. This should return half of the connections - the ones - // that we marked as closed. - // - // Create a copy of our session map with the sessions that have - // closed connections taken out. 
- expectedSessionConnIds := make(map[string][]string) - for sessionId, connIds := range sessionConnIds { - for _, connId := range connIds { - if testIsConnectionClosed(ctx, t, repo, connId) { - expectedSessionConnIds[sessionId] = append(expectedSessionConnIds[sessionId], connId) - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + resp, err := connRepo.closeConnections(context.Background(), tt.closeWith) + if tt.wantErr { + require.Error(err) + assert.Truef(errors.Match(errors.T(tt.wantIsError), err), "unexpected error %s", err.Error()) + return } - } - - // Send query, use all connections w/o a filter on sessions. - actualSessionConnIds, err := repo.ShouldCloseConnectionsOnWorker(ctx, connIds, nil) - require.NoError(err) - require.Equal(expectedSessionConnIds, actualSessionConnIds) - } - - { - // Finally, add a session filter. We do this by just alternating - // the session IDs we want to filter on. - expectedSessionConnIds := make(map[string][]string) - var filterSessionIds []string - var filterSession bool - for sessionId, connIds := range sessionConnIds { - for _, connId := range connIds { - if testIsConnectionClosed(ctx, t, repo, connId) { - if !filterSession { - expectedSessionConnIds[sessionId] = append(expectedSessionConnIds[sessionId], connId) - } else { - filterSessionIds = append(filterSessionIds, sessionId) - } - - // Toggle filterSession here (instead of just outer session - // loop) so that we aren't just lining up on - // connected/disconnected connections. - filterSession = !filterSession - } + require.NoError(err) + assert.Equal(len(tt.closeWith), len(resp)) + for _, r := range resp { + require.NotNil(r.Connection) + require.NotNil(r.ConnectionStates) + assert.Equal(StatusClosed, r.ConnectionStates[0].Status) } - } - - // Send query with the session filter. - actualSessionConnIds, err := repo.ShouldCloseConnectionsOnWorker(ctx, connIds, filterSessionIds) - require.NoError(err) - require.Equal(expectedSessionConnIds, actualSessionConnIds) + }) } } - -func testIsConnectionClosed(ctx context.Context, t *testing.T, repo *Repository, connId string) bool { - require := require.New(t) - _, states, err := repo.LookupConnection(ctx, connId) - require.NoError(err) - // Use first state as this LookupConnections returns ordered by - // start time, descending. - return states[0].Status == StatusClosed -} diff --git a/internal/session/repository_session.go b/internal/session/repository_session.go index 69b9d32028..c1d2f37a21 100644 --- a/internal/session/repository_session.go +++ b/internal/session/repository_session.go @@ -13,8 +13,6 @@ import ( "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/kms" wrapping "github.com/hashicorp/go-kms-wrapping" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" ) // CreateSession inserts into the repository and returns the new Session with @@ -138,7 +136,7 @@ func (r *Repository) CreateSession(ctx context.Context, sessionWrapper wrapping. // with its states. Returned States are ordered by start time descending. If the // session is not found, it will return nil, nil, nil. No options are currently // supported. 
-func (r *Repository) LookupSession(ctx context.Context, sessionId string, _ ...Option) (*Session, *ConnectionAuthzSummary, error) { +func (r *Repository) LookupSession(ctx context.Context, sessionId string, _ ...Option) (*Session, *AuthzSummary, error) { const op = "session.(Repository).LookupSession" if sessionId == "" { return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing session id") @@ -377,46 +375,23 @@ func (r *Repository) TerminateCompletedSessions(ctx context.Context) (int, error return rowsAffected, nil } -// AuthorizeConnection will check to see if a connection is allowed. Currently, -// that authorization checks: -// * the hasn't expired based on the session.Expiration -// * number of connections already created is less than session.ConnectionLimit -// If authorization is success, it creates/stores a new connection in the repo -// and returns it, along with its states. If the authorization fails, it -// an error with Code InvalidSessionState. -func (r *Repository) AuthorizeConnection(ctx context.Context, sessionId, workerId string) (*Connection, []*ConnectionState, *ConnectionAuthzSummary, error) { - const op = "session.(Repository).AuthorizeConnection" - if sessionId == "" { - return nil, nil, nil, errors.Wrap(ctx, status.Error(codes.FailedPrecondition, "missing session id"), op, errors.WithCode(errors.InvalidParameter)) - } - connectionId, err := newConnectionId() - if err != nil { - return nil, nil, nil, errors.Wrap(ctx, err, op) - } +// terminateSessionIfPossible is called on connection close and will attempt to close the connection's +// session if the following conditions are met: +// * sessions that have exhausted their connection limit and all their connections are closed. +// * sessions that are expired and all their connections are closed. 
+// * sessions that are canceling and all their connections are closed +func (r *Repository) terminateSessionIfPossible(ctx context.Context, sessionId string) (int, error) { + const op = "session.(Repository).terminateSessionIfPossible" + rowsAffected := 0 - connection := AllocConnection() - connection.PublicId = connectionId - var connectionStates []*ConnectionState - _, err = r.writer.DoTx( + _, err := r.writer.DoTx( ctx, db.StdRetryCnt, db.ExpBackoff{}, func(reader db.Reader, w db.Writer) error { - rowsAffected, err := w.Exec(ctx, authorizeConnectionCte, []interface{}{ - sql.Named("session_id", sessionId), - sql.Named("public_id", connectionId), - sql.Named("worker_id", workerId), - }) - if err != nil { - return errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("unable to authorize connection %s", sessionId))) - } - if rowsAffected == 0 { - return errors.Wrap(ctx, status.Errorf(codes.PermissionDenied, "session %s is not authorized (not active, expired or connection limit reached)", sessionId), op, errors.WithCode(errors.InvalidSessionState)) - } - if err := reader.LookupById(ctx, &connection); err != nil { - return errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("failed for session %s", sessionId))) - } - connectionStates, err = fetchConnectionStates(ctx, reader, connectionId, db.WithOrder("start_time desc")) + var err error + rowsAffected, err = w.Exec(ctx, terminateSessionIfPossible, + []interface{}{sql.Named("public_id", sessionId)}) if err != nil { return errors.Wrap(ctx, err, op) } @@ -424,22 +399,18 @@ func (r *Repository) AuthorizeConnection(ctx context.Context, sessionId, workerI }, ) if err != nil { - return nil, nil, nil, errors.Wrap(ctx, err, op) - } - authzSummary, err := r.sessionAuthzSummary(ctx, connection.SessionId) - if err != nil { - return nil, nil, nil, errors.Wrap(ctx, err, op) + return db.NoRowsAffected, errors.Wrap(ctx, err, op) } - return &connection, connectionStates, authzSummary, nil + return rowsAffected, nil } -type ConnectionAuthzSummary struct { +type AuthzSummary struct { ExpirationTime *timestamp.Timestamp ConnectionLimit int32 CurrentConnectionCount uint32 } -func (r *Repository) sessionAuthzSummary(ctx context.Context, sessionId string) (*ConnectionAuthzSummary, error) { +func (r *Repository) sessionAuthzSummary(ctx context.Context, sessionId string) (*AuthzSummary, error) { const op = "session.(Repository).sessionAuthzSummary" rows, err := r.reader.Query(ctx, remainingConnectionsCte, []interface{}{sql.Named("session_id", sessionId)}) if err != nil { @@ -447,12 +418,12 @@ func (r *Repository) sessionAuthzSummary(ctx context.Context, sessionId string) } defer rows.Close() - var info *ConnectionAuthzSummary + var info *AuthzSummary for rows.Next() { if info != nil { return nil, errors.New(ctx, errors.MultipleRecords, op, "query returned more than one row") } - info = &ConnectionAuthzSummary{} + info = &AuthzSummary{} if err := r.reader.ScanRows(ctx, rows, info); err != nil { return nil, errors.Wrap(ctx, err, op, errors.WithMsg("scan row failed")) } @@ -460,127 +431,6 @@ func (r *Repository) sessionAuthzSummary(ctx context.Context, sessionId string) return info, nil } -// ConnectConnection updates a connection in the repo with a state of "connected". -func (r *Repository) ConnectConnection(ctx context.Context, c ConnectWith) (*Connection, []*ConnectionState, error) { - const op = "session.(Repository).ConnectConnection" - // ConnectWith.validate will check all the fields... 
- if err := c.validate(); err != nil { - return nil, nil, errors.Wrap(ctx, err, op) - } - var connection Connection - var connectionStates []*ConnectionState - _, err := r.writer.DoTx( - ctx, - db.StdRetryCnt, - db.ExpBackoff{}, - func(reader db.Reader, w db.Writer) error { - connection = AllocConnection() - connection.PublicId = c.ConnectionId - connection.ClientTcpAddress = c.ClientTcpAddress - connection.ClientTcpPort = c.ClientTcpPort - connection.EndpointTcpAddress = c.EndpointTcpAddress - connection.EndpointTcpPort = c.EndpointTcpPort - connection.UserClientIp = c.UserClientIp - fieldMask := []string{ - "ClientTcpAddress", - "ClientTcpPort", - "EndpointTcpAddress", - "EndpointTcpPort", - "UserClientIp", - } - rowsUpdated, err := w.Update(ctx, &connection, fieldMask, nil) - if err != nil { - return errors.Wrap(ctx, err, op) - } - if err == nil && rowsUpdated > 1 { - // return err, which will result in a rollback of the update - return errors.New(ctx, errors.MultipleRecords, op, "more than 1 resource would have been updated") - } - newState, err := NewConnectionState(connection.PublicId, StatusConnected) - if err != nil { - return errors.Wrap(ctx, err, op) - } - if err := w.Create(ctx, newState); err != nil { - return errors.Wrap(ctx, err, op) - } - connectionStates, err = fetchConnectionStates(ctx, reader, c.ConnectionId, db.WithOrder("start_time desc")) - if err != nil { - return errors.Wrap(ctx, err, op) - } - return nil - }, - ) - if err != nil { - return nil, nil, errors.Wrap(ctx, err, op) - } - return &connection, connectionStates, nil -} - -// CloseConnectionResp is just a wrapper for the response from CloseConnections. -// It wraps the connection and its states for each connection closed. -type CloseConnectionResp struct { - Connection *Connection - ConnectionStates []*ConnectionState -} - -// CloseConnections set's a connection's state to "closed" in the repo. It's -// called by a worker after it's closed a connection between the client and the -// endpoint -func (r *Repository) CloseConnections(ctx context.Context, closeWith []CloseWith, _ ...Option) ([]CloseConnectionResp, error) { - const op = "session.(Repository).CloseConnections" - if len(closeWith) == 0 { - return nil, errors.New(ctx, errors.InvalidParameter, op, "missing connections") - } - for _, cw := range closeWith { - if err := cw.validate(); err != nil { - return nil, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("%s was invalid", cw.ConnectionId))) - } - } - var resp []CloseConnectionResp - _, err := r.writer.DoTx( - ctx, - db.StdRetryCnt, - db.ExpBackoff{}, - func(reader db.Reader, w db.Writer) error { - for _, cw := range closeWith { - updateConnection := AllocConnection() - updateConnection.PublicId = cw.ConnectionId - updateConnection.BytesUp = cw.BytesUp - updateConnection.BytesDown = cw.BytesDown - updateConnection.ClosedReason = cw.ClosedReason.String() - // updating the ClosedReason will trigger an insert into the - // session_connection_state with a state of closed. 
- rowsUpdated, err := w.Update( - ctx, - &updateConnection, - []string{"BytesUp", "BytesDown", "ClosedReason"}, - nil, - ) - if err != nil { - return errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("unable to update connection %s", cw.ConnectionId))) - } - if rowsUpdated != 1 { - return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("%d would have been updated for connection %s", rowsUpdated, cw.ConnectionId)) - } - states, err := fetchConnectionStates(ctx, reader, cw.ConnectionId, db.WithOrder("start_time desc")) - if err != nil { - return errors.Wrap(ctx, err, op) - } - resp = append(resp, CloseConnectionResp{ - Connection: &updateConnection, - ConnectionStates: states, - }) - - } - return nil - }, - ) - if err != nil { - return nil, errors.Wrap(ctx, err, op) - } - return resp, nil -} - // ActivateSession will activate the session and is called by a worker after // authenticating the session. The session must be in a "pending" state to be // activated. States are ordered by start time descending. Returns an @@ -738,6 +588,59 @@ func (r *Repository) updateState(ctx context.Context, sessionId string, sessionV return &updatedSession, returnedStates, nil } +// checkIfNoLongerActive checks the given sessions to see if they are in a +// non-active state, i.e. "canceling" or "terminated". +// It returns a StateReport for each session that is not active, with its current status. +func (r *Repository) checkIfNoLongerActive(ctx context.Context, reportedSessions []string) ([]StateReport, error) { + const op = "session.(Repository).checkIfNoLongerActive" + + notActive := make([]StateReport, 0, len(reportedSessions)) + args := make([]interface{}, 0, len(reportedSessions)) + var inClause string + + if len(reportedSessions) == 0 { + return notActive, nil + } + + inClause = `and session_id in (%s)` + params := make([]string, len(reportedSessions)) + for i, sessId := range reportedSessions { + params[i] = fmt.Sprintf("@%d", i) + args = append(args, sql.Named(fmt.Sprintf("%d", i), sessId)) + } + inClause = fmt.Sprintf(inClause, strings.Join(params, ",")) + + _, err := r.writer.DoTx( + ctx, + db.StdRetryCnt, + db.ExpBackoff{}, + func(reader db.Reader, w db.Writer) error { + rows, err := reader.Query(ctx, fmt.Sprintf(checkIfNotActive, inClause), args) + if err != nil { + return errors.Wrap(ctx, err, op) + } + defer rows.Close() + + for rows.Next() { + var sessionId string + var status Status + if err := rows.Scan(&sessionId, &status); err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("scan row failed")) + } + notActive = append(notActive, StateReport{ + SessionId: sessionId, + Status: status, + }) + } + return nil + }, + ) + if err != nil { + return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error checking if sessions are no longer active")) + } + return notActive, nil +} + func fetchStates(ctx context.Context, r db.Reader, sessionId string, opt ...db.Option) ([]*State, error) { const op = "session.fetchStates" var states []*State diff --git a/internal/session/repository_session_test.go b/internal/session/repository_session_test.go index b04475cab7..516c76281f 100644 --- a/internal/session/repository_session_test.go +++ b/internal/session/repository_session_test.go @@ -541,231 +541,16 @@ func TestRepository_updateState(t *testing.T) { } } -func TestRepository_AuthorizeConnect(t *testing.T) { - t.Parallel() - conn, _ := db.TestSetup(t, "postgres") - rw := db.New(conn) - wrapper := db.TestWrapper(t) - iamRepo := iam.TestRepo(t, conn, wrapper) - kms := kms.TestKms(t, conn, wrapper)
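checkIfNoLongerActive above builds its IN clause from named parameters (@0, @1, ...) instead of splicing the session ids into the SQL text. Distilled out of the repository, the technique looks like the following standalone sketch; the function name and ids are illustrative only:

```go
package main

import (
	"database/sql"
	"fmt"
	"strings"
)

// buildInClause returns an "and session_id in (@0,@1,...)" fragment plus the
// matching sql.Named arguments, mirroring checkIfNoLongerActive: the ids
// travel as bind parameters rather than as interpolated text.
func buildInClause(sessionIds []string) (string, []interface{}) {
	params := make([]string, len(sessionIds))
	args := make([]interface{}, 0, len(sessionIds))
	for i, id := range sessionIds {
		params[i] = fmt.Sprintf("@%d", i)
		args = append(args, sql.Named(fmt.Sprintf("%d", i), id))
	}
	return fmt.Sprintf("and session_id in (%s)", strings.Join(params, ",")), args
}

func main() {
	clause, args := buildInClause([]string{"s_one", "s_two"})
	fmt.Println(clause)    // and session_id in (@0,@1)
	fmt.Println(len(args)) // 2
}
```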
- repo, err := NewRepository(rw, rw, kms) - require.NoError(t, err) - - var testServer string - setupFn := func(exp *timestamp.Timestamp) *Session { - composedOf := TestSessionParams(t, conn, wrapper, iamRepo) - if exp != nil { - composedOf.ExpirationTime = exp - } - s := TestSession(t, conn, wrapper, composedOf) - srv := TestWorker(t, conn, wrapper) - testServer = srv.PrivateId - tofu := TestTofu(t) - _, _, err := repo.ActivateSession(context.Background(), s.PublicId, s.Version, srv.PrivateId, srv.Type, tofu) - require.NoError(t, err) - return s - } - testSession := setupFn(nil) - - tests := []struct { - name string - session *Session - wantErr bool - wantIsError error - wantAuthzInfo ConnectionAuthzSummary - }{ - { - name: "valid", - session: testSession, - wantAuthzInfo: ConnectionAuthzSummary{ - ConnectionLimit: 1, - CurrentConnectionCount: 1, - ExpirationTime: testSession.ExpirationTime, - }, - }, - { - name: "empty-sessionId", - session: func() *Session { - s := AllocSession() - return &s - }(), - wantErr: true, - }, - { - name: "exceeded-connection-limit", - session: func() *Session { - session := setupFn(nil) - _ = TestConnection(t, conn, session.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222, "127.0.0.1") - return session - }(), - wantErr: true, - }, - { - name: "expired-session", - session: setupFn(&timestamp.Timestamp{Timestamp: timestamppb.Now()}), - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - assert, require := assert.New(t), require.New(t) - - c, cs, authzInfo, err := repo.AuthorizeConnection(context.Background(), tt.session.PublicId, testServer) - if tt.wantErr { - require.Error(err) - // TODO (jimlambrt 9/2020): add in tests for errorsIs once we - // remove the grpc errors from the repo. - // if tt.wantIsError != nil { - // assert.Truef(errors.Is(err, tt.wantIsError), "unexpected error %s", err.Error()) - // } - return - } - require.NoError(err) - require.NotNil(c) - require.NotNil(cs) - assert.Equal(StatusAuthorized, cs[0].Status) - - assert.True(authzInfo.ExpirationTime.GetTimestamp().AsTime().Sub(tt.wantAuthzInfo.ExpirationTime.GetTimestamp().AsTime()) < 10*time.Millisecond) - tt.wantAuthzInfo.ExpirationTime = authzInfo.ExpirationTime - - assert.Equal(tt.wantAuthzInfo.ExpirationTime, authzInfo.ExpirationTime) - assert.Equal(tt.wantAuthzInfo.ConnectionLimit, authzInfo.ConnectionLimit) - assert.Equal(tt.wantAuthzInfo.CurrentConnectionCount, authzInfo.CurrentConnectionCount) - }) - } -} - -func TestRepository_ConnectConnection(t *testing.T) { - t.Parallel() - conn, _ := db.TestSetup(t, "postgres") - rw := db.New(conn) - wrapper := db.TestWrapper(t) - iamRepo := iam.TestRepo(t, conn, wrapper) - kms := kms.TestKms(t, conn, wrapper) - repo, err := NewRepository(rw, rw, kms) - require.NoError(t, err) - - setupFn := func() ConnectWith { - s := TestDefaultSession(t, conn, wrapper, iamRepo) - srv := TestWorker(t, conn, wrapper) - tofu := TestTofu(t) - _, _, err := repo.ActivateSession(context.Background(), s.PublicId, s.Version, srv.PrivateId, srv.Type, tofu) - require.NoError(t, err) - c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222, "127.0.0.1") - return ConnectWith{ - ConnectionId: c.PublicId, - ClientTcpAddress: "127.0.0.1", - ClientTcpPort: 22, - EndpointTcpAddress: "127.0.0.1", - EndpointTcpPort: 2222, - UserClientIp: "127.0.0.1", - } - } - tests := []struct { - name string - connectWith ConnectWith - wantErr bool - wantIsError errors.Code - }{ - { - name: "valid", - connectWith: setupFn(), - }, - { - 
name: "empty-SessionId", - connectWith: func() ConnectWith { - cw := setupFn() - cw.ConnectionId = "" - return cw - }(), - wantErr: true, - wantIsError: errors.InvalidParameter, - }, - { - name: "empty-ClientTcpAddress", - connectWith: func() ConnectWith { - cw := setupFn() - cw.ClientTcpAddress = "" - return cw - }(), - wantErr: true, - wantIsError: errors.InvalidParameter, - }, - { - name: "empty-ClientTcpPort", - connectWith: func() ConnectWith { - cw := setupFn() - cw.ClientTcpPort = 0 - return cw - }(), - wantErr: true, - wantIsError: errors.InvalidParameter, - }, - { - name: "empty-EndpointTcpAddress", - connectWith: func() ConnectWith { - cw := setupFn() - cw.EndpointTcpAddress = "" - return cw - }(), - wantErr: true, - wantIsError: errors.InvalidParameter, - }, - { - name: "empty-EndpointTcpPort", - connectWith: func() ConnectWith { - cw := setupFn() - cw.EndpointTcpPort = 0 - return cw - }(), - wantErr: true, - wantIsError: errors.InvalidParameter, - }, - { - name: "empty-UserClientIp", - connectWith: func() ConnectWith { - cw := setupFn() - cw.UserClientIp = "" - return cw - }(), - wantErr: true, - wantIsError: errors.InvalidParameter, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - assert, require := assert.New(t), require.New(t) - - c, cs, err := repo.ConnectConnection(context.Background(), tt.connectWith) - if tt.wantErr { - require.Error(err) - assert.Truef(errors.Match(errors.T(tt.wantIsError), err), "unexpected error %s", err.Error()) - return - } - require.NoError(err) - require.NotNil(c) - require.NotNil(cs) - assert.Equal(StatusConnected, cs[0].Status) - gotConn, _, err := repo.LookupConnection(context.Background(), c.PublicId) - require.NoError(err) - assert.Equal(tt.connectWith.ClientTcpAddress, gotConn.ClientTcpAddress) - assert.Equal(tt.connectWith.ClientTcpPort, gotConn.ClientTcpPort) - assert.Equal(tt.connectWith.ClientTcpAddress, gotConn.ClientTcpAddress) - assert.Equal(tt.connectWith.EndpointTcpAddress, gotConn.EndpointTcpAddress) - assert.Equal(tt.connectWith.EndpointTcpPort, gotConn.EndpointTcpPort) - assert.Equal(tt.connectWith.UserClientIp, gotConn.UserClientIp) - }) - } -} - func TestRepository_TerminateCompletedSessions(t *testing.T) { t.Parallel() + ctx := context.Background() conn, _ := db.TestSetup(t, "postgres") rw := db.New(conn) wrapper := db.TestWrapper(t) iamRepo := iam.TestRepo(t, conn, wrapper) kms := kms.TestKms(t, conn, wrapper) repo, err := NewRepository(rw, rw, kms) + connRepo, err := NewConnectionRepository(ctx, rw, rw, kms) require.NoError(t, err) setupFn := func(limit int32, expireIn time.Duration, leaveOpen bool) *Session { @@ -788,7 +573,7 @@ func TestRepository_TerminateCompletedSessions(t *testing.T) { BytesDown: 1, ClosedReason: ConnectionClosedByUser, } - _, err = repo.CloseConnections(context.Background(), []CloseWith{cw}) + _, err = connRepo.closeConnections(context.Background(), []CloseWith{cw}) require.NoError(t, err) } return s @@ -967,10 +752,10 @@ func TestRepository_TerminateCompletedSessions(t *testing.T) { } assert.Equal(args.wantTermed[found.PublicId].String(), found.TerminationReason) t.Logf("terminated %s has a connection limit of %d", found.PublicId, found.ConnectionLimit) - conn, err := repo.ListConnectionsBySessionId(context.Background(), found.PublicId) + conn, err := connRepo.ListConnectionsBySessionId(context.Background(), found.PublicId) require.NoError(err) for _, sc := range conn { - c, cs, err := repo.LookupConnection(context.Background(), sc.PublicId) + c, cs, err := 
connRepo.LookupConnection(context.Background(), sc.PublicId) require.NoError(err) assert.NotEmpty(c.ClosedReason) for _, s := range cs { @@ -980,7 +765,7 @@ func TestRepository_TerminateCompletedSessions(t *testing.T) { } else { t.Logf("not terminated %s has a connection limit of %d", found.PublicId, found.ConnectionLimit) assert.Equal("", found.TerminationReason) - conn, err := repo.ListConnectionsBySessionId(context.Background(), found.PublicId) + conn, err := connRepo.ListConnectionsBySessionId(context.Background(), found.PublicId) require.NoError(err) for _, sc := range conn { cs, err := fetchConnectionStates(context.Background(), rw, sc.PublicId) @@ -996,94 +781,16 @@ func TestRepository_TerminateCompletedSessions(t *testing.T) { } } -func TestRepository_CloseConnections(t *testing.T) { - t.Parallel() - conn, _ := db.TestSetup(t, "postgres") - rw := db.New(conn) - wrapper := db.TestWrapper(t) - iamRepo := iam.TestRepo(t, conn, wrapper) - kms := kms.TestKms(t, conn, wrapper) - repo, err := NewRepository(rw, rw, kms) - require.NoError(t, err) - - setupFn := func(cnt int) []CloseWith { - s := TestDefaultSession(t, conn, wrapper, iamRepo) - srv := TestWorker(t, conn, wrapper) - tofu := TestTofu(t) - s, _, err := repo.ActivateSession(context.Background(), s.PublicId, s.Version, srv.PrivateId, srv.Type, tofu) - require.NoError(t, err) - cw := make([]CloseWith, 0, cnt) - for i := 0; i < cnt; i++ { - c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222, "127.0.0.1") - require.NoError(t, err) - cw = append(cw, CloseWith{ - ConnectionId: c.PublicId, - BytesUp: 1, - BytesDown: 2, - ClosedReason: ConnectionClosedByUser, - }) - } - return cw - } - tests := []struct { - name string - closeWith []CloseWith - reason TerminationReason - wantErr bool - wantIsError errors.Code - }{ - { - name: "valid", - closeWith: setupFn(2), - reason: ClosedByUser, - }, - { - name: "empty-closed-with", - closeWith: []CloseWith{}, - reason: ClosedByUser, - wantErr: true, - wantIsError: errors.InvalidParameter, - }, - { - name: "missing-ConnectionId", - closeWith: func() []CloseWith { - cw := setupFn(2) - cw[1].ConnectionId = "" - return cw - }(), - reason: ClosedByUser, - wantErr: true, - wantIsError: errors.InvalidParameter, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - assert, require := assert.New(t), require.New(t) - resp, err := repo.CloseConnections(context.Background(), tt.closeWith) - if tt.wantErr { - require.Error(err) - assert.Truef(errors.Match(errors.T(tt.wantIsError), err), "unexpected error %s", err.Error()) - return - } - require.NoError(err) - assert.Equal(len(tt.closeWith), len(resp)) - for _, r := range resp { - require.NotNil(r.Connection) - require.NotNil(r.ConnectionStates) - assert.Equal(StatusClosed, r.ConnectionStates[0].Status) - } - }) - } -} - func TestRepository_CancelSession(t *testing.T) { t.Parallel() + ctx := context.Background() conn, _ := db.TestSetup(t, "postgres") rw := db.New(conn) wrapper := db.TestWrapper(t) iamRepo := iam.TestRepo(t, conn, wrapper) kms := kms.TestKms(t, conn, wrapper) repo, err := NewRepository(rw, rw, kms) + connRepo, err := NewConnectionRepository(ctx, rw, rw, kms) require.NoError(t, err) setupFn := func() *Session { session := TestDefaultSession(t, conn, wrapper, iamRepo) @@ -1115,11 +822,13 @@ func TestRepository_CancelSession(t *testing.T) { BytesDown: 1, ClosedReason: ConnectionClosedByUser, } - _, err = repo.CloseConnections(context.Background(), []CloseWith{cw}) + _, err = CloseConnections(ctx, repo, 
connRepo, []CloseWith{cw}) require.NoError(t, err) - s, _, err := repo.LookupSession(context.Background(), session.PublicId) + s, _, err := repo.LookupSession(ctx, session.PublicId) require.NoError(t, err) assert.Equal(t, StatusTerminated, s.States[0].Status) + // Closing the connections and then terminating the session are two separate transactions, so the session version will be 2, not 1 + session.Version = s.Version return session }(), wantStatus: StatusTerminated, @@ -1185,7 +894,7 @@ func TestRepository_CancelSession(t *testing.T) { default: version = tt.session.Version } - s, err := repo.CancelSession(context.Background(), id, version) + s, err := repo.CancelSession(ctx, id, version) if tt.wantErr { require.Error(err) assert.Truef(errors.Match(errors.T(tt.wantIsError), err), "unexpected error %s", err.Error()) diff --git a/internal/session/service_authorize_connection.go b/internal/session/service_authorize_connection.go new file mode 100644 index 0000000000..c5130e3ef5 --- /dev/null +++ b/internal/session/service_authorize_connection.go @@ -0,0 +1,29 @@ +package session + +import ( + "context" + + "github.com/hashicorp/boundary/internal/errors" +) + +// AuthorizeConnection is a domain service function that will create a Connection +// for a session if the following criteria are met: +// * The session is active. +// * The session is not expired. +// * The session has not reached its connection limit or has a connection limit of -1. +// If any of these criteria are not met, it returns an error with Code InvalidSessionState. +func AuthorizeConnection(ctx context.Context, sessionRepoFn *Repository, connectionRepoFn *ConnectionRepository, + sessionId, workerId string, opt ...Option) (*Connection, []*ConnectionState, *AuthzSummary, error) { + const op = "session.AuthorizeConnection" + + connection, connectionStates, err := connectionRepoFn.AuthorizeConnection(ctx, sessionId, workerId) + if err != nil { + return nil, nil, nil, errors.Wrap(ctx, err, op) + } + + authzSummary, err := sessionRepoFn.sessionAuthzSummary(ctx, sessionId) + if err != nil { + return nil, nil, nil, errors.Wrap(ctx, err, op) + } + return connection, connectionStates, authzSummary, nil +} diff --git a/internal/session/service_authorize_connection_test.go b/internal/session/service_authorize_connection_test.go new file mode 100644 index 0000000000..90d6b77446 --- /dev/null +++ b/internal/session/service_authorize_connection_test.go @@ -0,0 +1,112 @@ +package session + +import ( + "context" + "testing" + "time" + + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/boundary/internal/db/timestamp" + "github.com/hashicorp/boundary/internal/iam" + "github.com/hashicorp/boundary/internal/kms" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestService_AuthorizeConnection(t *testing.T) { + t.Parallel() + ctx := context.Background() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + iamRepo := iam.TestRepo(t, conn, wrapper) + kms := kms.TestKms(t, conn, wrapper) + repo, err := NewRepository(rw, rw, kms) + connRepo, err := NewConnectionRepository(ctx, rw, rw, kms) + require.NoError(t, err) + + var testServer string + setupFn := func(exp *timestamp.Timestamp) *Session { + composedOf := TestSessionParams(t, conn, wrapper, iamRepo) + if exp != nil { + composedOf.ExpirationTime = exp + } + s := TestSession(t, conn, wrapper, composedOf) + srv := TestWorker(t, conn, wrapper) + testServer = 
srv.PrivateId + tofu := TestTofu(t) + _, _, err := repo.ActivateSession(context.Background(), s.PublicId, s.Version, srv.PrivateId, srv.Type, tofu) + require.NoError(t, err) + return s + } + testSession := setupFn(nil) + + tests := []struct { + name string + session *Session + wantErr bool + wantIsError error + wantAuthzInfo AuthzSummary + }{ + { + name: "valid", + session: testSession, + wantAuthzInfo: AuthzSummary{ + ConnectionLimit: 1, + CurrentConnectionCount: 1, + ExpirationTime: testSession.ExpirationTime, + }, + }, + { + name: "empty-sessionId", + session: func() *Session { + s := AllocSession() + return &s + }(), + wantErr: true, + }, + { + name: "exceeded-connection-limit", + session: func() *Session { + session := setupFn(nil) + _ = TestConnection(t, conn, session.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222, "127.0.0.1") + return session + }(), + wantErr: true, + }, + { + name: "expired-session", + session: setupFn(&timestamp.Timestamp{Timestamp: timestamppb.Now()}), + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + c, cs, authzInfo, err := AuthorizeConnection(context.Background(), repo, connRepo, tt.session.PublicId, testServer) + if tt.wantErr { + require.Error(err) + // TODO (jimlambrt 9/2020): add in tests for errorsIs once we - // remove the grpc errors from the repo. + // if tt.wantIsError != nil { + // assert.Truef(errors.Is(err, tt.wantIsError), "unexpected error %s", err.Error()) + // } + return + } + require.NoError(err) + require.NotNil(c) + require.NotNil(cs) + assert.Equal(StatusAuthorized, cs[0].Status) + + assert.True(authzInfo.ExpirationTime.GetTimestamp().AsTime().Sub(tt.wantAuthzInfo.ExpirationTime.GetTimestamp().AsTime()) < 10*time.Millisecond) + tt.wantAuthzInfo.ExpirationTime = authzInfo.ExpirationTime + + assert.Equal(tt.wantAuthzInfo.ExpirationTime, authzInfo.ExpirationTime) + assert.Equal(tt.wantAuthzInfo.ConnectionLimit, authzInfo.ConnectionLimit) + assert.Equal(tt.wantAuthzInfo.CurrentConnectionCount, authzInfo.CurrentConnectionCount) + }) + } +} diff --git a/internal/session/service_close_connections.go b/internal/session/service_close_connections.go new file mode 100644 index 0000000000..34971722ab --- /dev/null +++ b/internal/session/service_close_connections.go @@ -0,0 +1,31 @@ +package session + +import ( + "context" + + "github.com/hashicorp/boundary/internal/errors" +) + +// CloseConnections is a domain service function that: +// * closes requested connections +// * uses the sessionId of each closed connection to see if its session now meets the conditions for termination +func CloseConnections(ctx context.Context, sessionRepoFn *Repository, connectionRepoFn *ConnectionRepository, + closeWiths []CloseWith) ([]closeConnectionResp, error) { + const op = "session.CloseConnections" + + closeInfos, err := connectionRepoFn.closeConnections(ctx, closeWiths) + if err != nil { + return nil, errors.Wrap(ctx, err, op) + } + + // Attempt to terminate only once per sessionId; termination is best-effort here, since the connections have already been closed + sessionIdsProcessed := make(map[string]bool) + for _, c := range closeInfos { + if !sessionIdsProcessed[c.Connection.SessionId] { + _, _ = sessionRepoFn.terminateSessionIfPossible(ctx, c.Connection.SessionId) + sessionIdsProcessed[c.Connection.SessionId] = true + } + } + + return closeInfos, nil +} diff --git a/internal/session/service_close_connections_test.go b/internal/session/service_close_connections_test.go new file mode 100644 index 0000000000..8eb5d3fa61 --- /dev/null +++ 
b/internal/session/service_close_connections_test.go @@ -0,0 +1,97 @@ +package session + +import ( + "context" + "testing" + + "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/boundary/internal/iam" + "github.com/hashicorp/boundary/internal/kms" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestServiceCloseConnections(t *testing.T) { + t.Parallel() + ctx := context.Background() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + iamRepo := iam.TestRepo(t, conn, wrapper) + kms := kms.TestKms(t, conn, wrapper) + repo, err := NewRepository(rw, rw, kms) + connRepo, err := NewConnectionRepository(ctx, rw, rw, kms) + require.NoError(t, err) + + type sessionAndCloseWiths struct { + session *Session + closeWith []CloseWith + } + + setupFn := func(cnt int, addtlConn int) sessionAndCloseWiths { + s := TestDefaultSession(t, conn, wrapper, iamRepo) + srv := TestWorker(t, conn, wrapper) + tofu := TestTofu(t) + s, _, err = repo.ActivateSession(context.Background(), s.PublicId, s.Version, srv.PrivateId, srv.Type, tofu) + require.NoError(t, err) + + require.NoError(t, err) + cw := make([]CloseWith, 0, cnt) + for i := 0; i < cnt; i++ { + c := TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222, "127.0.0.1") + require.NoError(t, err) + cw = append(cw, CloseWith{ + ConnectionId: c.PublicId, + BytesUp: 1, + BytesDown: 2, + ClosedReason: ConnectionClosedByUser, + }) + } + + for i := 0; i < addtlConn; i++ { + TestConnection(t, conn, s.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222, "127.0.0.1") + require.NoError(t, err) + } + return sessionAndCloseWiths{s, cw} + } + + tests := []struct { + name string + sessionCW sessionAndCloseWiths + wantClosedSession bool + }{ + { + name: "close-multiple-connections-and-session", + sessionCW: setupFn(4, 0), + wantClosedSession: true, + }, + { + name: "close-subset-of-connections", + sessionCW: setupFn(2, 1), + wantClosedSession: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + resp, err := CloseConnections(ctx, repo, connRepo, tt.sessionCW.closeWith) + require.NoError(err) + + for _, r := range resp { + require.NotNil(r.Connection) + require.NotNil(r.ConnectionStates) + assert.Equal(StatusClosed, r.ConnectionStates[0].Status) + } + + // Ensure the session is in the state we want: terminated if all conns were closed, else still active + ses, _, err := repo.LookupSession(ctx, tt.sessionCW.session.PublicId) + require.NoError(err) + if tt.wantClosedSession { + assert.Equal(StatusTerminated, ses.States[0].Status) + } else { + assert.Equal(StatusActive, ses.States[0].Status) + } + }) + } +} diff --git a/internal/session/service_worker_status_report.go b/internal/session/service_worker_status_report.go new file mode 100644 index 0000000000..0cc62a426a --- /dev/null +++ b/internal/session/service_worker_status_report.go @@ -0,0 +1,48 @@ +package session + +import ( + "context" + "fmt" + + "github.com/hashicorp/boundary/internal/errors" + "github.com/hashicorp/boundary/internal/observability/event" +) + +// StateReport is used to report on the state of a Session. +type StateReport struct { + SessionId string + Status Status + ConnectionIds []string +} + +// WorkerStatusReport is a domain service function that compares the state of +// sessions and connections as reported by a Worker against the known state in the +// repositories.
It returns a StateReport for each reported session that is in the + canceling or terminated state. It also checks for any orphaned + connections, that is, connections that are in an active state + but were not reported by the worker. Any orphaned connections will be marked as + closed. +func WorkerStatusReport(ctx context.Context, repo *Repository, connRepo *ConnectionRepository, workerId string, report []StateReport) ([]StateReport, error) { + const op = "session.WorkerStatusReport" + + reportedConnections := make([]string, 0) + reportedSessions := make([]string, 0, len(report)) + for _, r := range report { + reportedSessions = append(reportedSessions, r.SessionId) + reportedConnections = append(reportedConnections, r.ConnectionIds...) + } + + notActive, err := repo.checkIfNoLongerActive(ctx, reportedSessions) + if err != nil { + return nil, errors.New(ctx, errors.Internal, op, fmt.Sprintf("Error checking session state for worker %s: %v", workerId, err)) + } + + closed, err := connRepo.closeOrphanedConnections(ctx, workerId, reportedConnections) + if err != nil { + return notActive, errors.New(ctx, errors.Internal, op, fmt.Sprintf("Error closing orphaned connections for worker %s: %v", workerId, err)) + } + if len(closed) > 0 { + event.WriteSysEvent(ctx, op, "marked unclaimed connections as closed", "server_id", workerId, "count", len(closed)) + } + return notActive, nil +} diff --git a/internal/session/service_worker_status_report_test.go b/internal/session/service_worker_status_report_test.go new file mode 100644 index 0000000000..0589bdea35 --- /dev/null +++ b/internal/session/service_worker_status_report_test.go @@ -0,0 +1,394 @@ +package session_test + +import ( + "context" + "testing" + + "github.com/hashicorp/boundary/internal/authtoken" + "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/boundary/internal/host/static" + "github.com/hashicorp/boundary/internal/iam" + "github.com/hashicorp/boundary/internal/kms" + "github.com/hashicorp/boundary/internal/servers" + "github.com/hashicorp/boundary/internal/session" + "github.com/hashicorp/boundary/internal/target" + "github.com/hashicorp/boundary/internal/target/tcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestWorkerStatusReport(t *testing.T) { + ctx := context.Background() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + kms := kms.TestKms(t, conn, wrapper) + org, prj := iam.TestScopes(t, iam.TestRepo(t, conn, wrapper)) + + serverRepo, _ := servers.NewRepository(rw, rw, kms) + serverRepo.UpsertServer(ctx, &servers.Server{ + PrivateId: "test_controller1", + Type: "controller", + Address: "127.0.0.1", + }) + serverRepo.UpsertServer(ctx, &servers.Server{ + PrivateId: "test_worker1", + Type: "worker", + Address: "127.0.0.1", + }) + + repo, err := session.NewRepository(rw, rw, kms) + require.NoError(t, err) + connRepo, err := session.NewConnectionRepository(ctx, rw, rw, kms, session.WithWorkerStateDelay(0)) + require.NoError(t, err) + + at := authtoken.TestAuthToken(t, conn, kms, org.GetPublicId()) + uId := at.GetIamUserId() + hc := static.TestCatalogs(t, conn, prj.GetPublicId(), 1)[0] + hs := static.TestSets(t, conn, hc.GetPublicId(), 1)[0] + h := static.TestHosts(t, conn, hc.GetPublicId(), 1)[0] + static.TestSetMembers(t, conn, hs.GetPublicId(), []*static.Host{h}) + tar := tcp.TestTarget( + ctx, + t, conn, prj.GetPublicId(), "test", + target.WithHostSources([]string{hs.GetPublicId()}), + 
target.WithSessionConnectionLimit(-1), + ) + + type testCase struct { + worker *servers.Server + req []session.StateReport + want []session.StateReport + orphanedConnections []string + } + cases := []struct { + name string + caseFn func(t *testing.T) testCase + }{ + { + name: "No Sessions", + caseFn: func(t *testing.T) testCase { + worker := session.TestWorker(t, conn, wrapper) + return testCase{ + worker: worker, + req: []session.StateReport{}, + want: []session.StateReport{}, + } + }, + }, + { + name: "No Sessions already canceled", + caseFn: func(t *testing.T) testCase { + worker := session.TestWorker(t, conn, wrapper) + sess := session.TestSession(t, conn, wrapper, session.ComposedOf{ + UserId: uId, + HostId: h.GetPublicId(), + TargetId: tar.GetPublicId(), + HostSetId: hs.GetPublicId(), + AuthTokenId: at.GetPublicId(), + ScopeId: prj.GetPublicId(), + Endpoint: "tcp://127.0.0.1:22", + ConnectionLimit: 10, + }) + tofu := session.TestTofu(t) + sess, _, err = repo.ActivateSession(ctx, sess.PublicId, sess.Version, worker.PrivateId, worker.Type, tofu) + require.NoError(t, err) + require.NoError(t, err) + + _, _, err = connRepo.AuthorizeConnection(ctx, sess.PublicId, worker.PrivateId) + require.NoError(t, err) + + _, err = repo.CancelSession(ctx, sess.PublicId, sess.Version) + require.NoError(t, err) + + return testCase{ + worker: worker, + req: []session.StateReport{}, + want: []session.StateReport{}, + } + }, + }, + { + name: "Still Active", + caseFn: func(t *testing.T) testCase { + worker := session.TestWorker(t, conn, wrapper) + sess := session.TestSession(t, conn, wrapper, session.ComposedOf{ + UserId: uId, + HostId: h.GetPublicId(), + TargetId: tar.GetPublicId(), + HostSetId: hs.GetPublicId(), + AuthTokenId: at.GetPublicId(), + ScopeId: prj.GetPublicId(), + Endpoint: "tcp://127.0.0.1:22", + ConnectionLimit: 10, + }) + tofu := session.TestTofu(t) + sess, _, err = repo.ActivateSession(ctx, sess.PublicId, sess.Version, worker.PrivateId, worker.Type, tofu) + require.NoError(t, err) + require.NoError(t, err) + + connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker.PrivateId) + require.NoError(t, err) + return testCase{ + worker: worker, + req: []session.StateReport{ + { + SessionId: sess.PublicId, + Status: session.StatusActive, + ConnectionIds: []string{connection.PublicId}, + }, + }, + want: []session.StateReport{}, + } + }, + }, + { + name: "SessionClosed", + caseFn: func(t *testing.T) testCase { + worker := session.TestWorker(t, conn, wrapper) + sess := session.TestSession(t, conn, wrapper, session.ComposedOf{ + UserId: uId, + HostId: h.GetPublicId(), + TargetId: tar.GetPublicId(), + HostSetId: hs.GetPublicId(), + AuthTokenId: at.GetPublicId(), + ScopeId: prj.GetPublicId(), + Endpoint: "tcp://127.0.0.1:22", + ConnectionLimit: 10, + }) + tofu := session.TestTofu(t) + sess, _, err = repo.ActivateSession(ctx, sess.PublicId, sess.Version, worker.PrivateId, worker.Type, tofu) + require.NoError(t, err) + connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker.PrivateId) + require.NoError(t, err) + _, err = repo.CancelSession(ctx, sess.PublicId, sess.Version) + require.NoError(t, err) + + return testCase{ + worker: worker, + req: []session.StateReport{ + { + SessionId: sess.PublicId, + Status: session.StatusActive, + ConnectionIds: []string{connection.PublicId}, + }, + }, + want: []session.StateReport{ + { + SessionId: sess.PublicId, + Status: session.StatusCanceling, + }, + }, + } + }, + }, + { + name: "MultipleSessionsClosed", + caseFn: func(t 
*testing.T) testCase { + worker := session.TestWorker(t, conn, wrapper) + sess := session.TestSession(t, conn, wrapper, session.ComposedOf{ + UserId: uId, + HostId: h.GetPublicId(), + TargetId: tar.GetPublicId(), + HostSetId: hs.GetPublicId(), + AuthTokenId: at.GetPublicId(), + ScopeId: prj.GetPublicId(), + Endpoint: "tcp://127.0.0.1:22", + ConnectionLimit: 10, + }) + tofu := session.TestTofu(t) + sess, _, err = repo.ActivateSession(ctx, sess.PublicId, sess.Version, worker.PrivateId, worker.Type, tofu) + require.NoError(t, err) + connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker.PrivateId) + require.NoError(t, err) + _, err = repo.CancelSession(ctx, sess.PublicId, sess.Version) + require.NoError(t, err) + + sess2 := session.TestSession(t, conn, wrapper, session.ComposedOf{ + UserId: uId, + HostId: h.GetPublicId(), + TargetId: tar.GetPublicId(), + HostSetId: hs.GetPublicId(), + AuthTokenId: at.GetPublicId(), + ScopeId: prj.GetPublicId(), + Endpoint: "tcp://127.0.0.1:22", + ConnectionLimit: 10, + }) + tofu2 := session.TestTofu(t) + sess2, _, err = repo.ActivateSession(ctx, sess2.PublicId, sess2.Version, worker.PrivateId, worker.Type, tofu2) + require.NoError(t, err) + connection2, _, err := connRepo.AuthorizeConnection(ctx, sess2.PublicId, worker.PrivateId) + require.NoError(t, err) + _, err = repo.CancelSession(ctx, sess2.PublicId, sess2.Version) + require.NoError(t, err) + + return testCase{ + worker: worker, + req: []session.StateReport{ + { + SessionId: sess.PublicId, + Status: session.StatusActive, + ConnectionIds: []string{connection.PublicId}, + }, + { + SessionId: sess2.PublicId, + Status: session.StatusActive, + ConnectionIds: []string{connection2.PublicId}, + }, + }, + want: []session.StateReport{ + { + SessionId: sess.PublicId, + Status: session.StatusCanceling, + }, + { + SessionId: sess2.PublicId, + Status: session.StatusCanceling, + }, + }, + } + }, + }, + { + name: "OrphanedConnection", + caseFn: func(t *testing.T) testCase { + worker := session.TestWorker(t, conn, wrapper) + sess := session.TestSession(t, conn, wrapper, session.ComposedOf{ + UserId: uId, + HostId: h.GetPublicId(), + TargetId: tar.GetPublicId(), + HostSetId: hs.GetPublicId(), + AuthTokenId: at.GetPublicId(), + ScopeId: prj.GetPublicId(), + Endpoint: "tcp://127.0.0.1:22", + ConnectionLimit: 10, + }) + tofu := session.TestTofu(t) + sess, _, err = repo.ActivateSession(ctx, sess.PublicId, sess.Version, worker.PrivateId, worker.Type, tofu) + require.NoError(t, err) + connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker.PrivateId) + require.NoError(t, err) + + sess2 := session.TestSession(t, conn, wrapper, session.ComposedOf{ + UserId: uId, + HostId: h.GetPublicId(), + TargetId: tar.GetPublicId(), + HostSetId: hs.GetPublicId(), + AuthTokenId: at.GetPublicId(), + ScopeId: prj.GetPublicId(), + Endpoint: "tcp://127.0.0.1:22", + ConnectionLimit: 10, + }) + tofu2 := session.TestTofu(t) + sess2, _, err = repo.ActivateSession(ctx, sess2.PublicId, sess2.Version, worker.PrivateId, worker.Type, tofu2) + require.NoError(t, err) + connection2, _, err := connRepo.AuthorizeConnection(ctx, sess2.PublicId, worker.PrivateId) + require.NoError(t, err) + require.NotEqual(t, connection.PublicId, connection2.PublicId) + + return testCase{ + worker: worker, + req: []session.StateReport{ + { + SessionId: sess2.PublicId, + Status: session.StatusActive, + ConnectionIds: []string{connection2.PublicId}, + }, + }, + want: []session.StateReport{}, + orphanedConnections: 
[]string{connection.PublicId}, + } + }, + }, + { + name: "MultipleSessionsAndOrphanedConnections", + caseFn: func(t *testing.T) testCase { + worker := session.TestWorker(t, conn, wrapper) + sess := session.TestSession(t, conn, wrapper, session.ComposedOf{ + UserId: uId, + HostId: h.GetPublicId(), + TargetId: tar.GetPublicId(), + HostSetId: hs.GetPublicId(), + AuthTokenId: at.GetPublicId(), + ScopeId: prj.GetPublicId(), + Endpoint: "tcp://127.0.0.1:22", + ConnectionLimit: 10, + }) + tofu := session.TestTofu(t) + sess, _, err = repo.ActivateSession(ctx, sess.PublicId, sess.Version, worker.PrivateId, worker.Type, tofu) + require.NoError(t, err) + connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker.PrivateId) + require.NoError(t, err) + _, err = repo.CancelSession(ctx, sess.PublicId, sess.Version) + require.NoError(t, err) + + sess2 := session.TestSession(t, conn, wrapper, session.ComposedOf{ + UserId: uId, + HostId: h.GetPublicId(), + TargetId: tar.GetPublicId(), + HostSetId: hs.GetPublicId(), + AuthTokenId: at.GetPublicId(), + ScopeId: prj.GetPublicId(), + Endpoint: "tcp://127.0.0.1:22", + ConnectionLimit: 10, + }) + tofu2 := session.TestTofu(t) + sess2, _, err = repo.ActivateSession(ctx, sess2.PublicId, sess2.Version, worker.PrivateId, worker.Type, tofu2) + require.NoError(t, err) + connection2, _, err := connRepo.AuthorizeConnection(ctx, sess2.PublicId, worker.PrivateId) + require.NoError(t, err) + connection3, _, err := connRepo.AuthorizeConnection(ctx, sess2.PublicId, worker.PrivateId) + require.NoError(t, err) + _, err = repo.CancelSession(ctx, sess2.PublicId, sess2.Version) + require.NoError(t, err) + + return testCase{ + worker: worker, + req: []session.StateReport{ + { + SessionId: sess.PublicId, + Status: session.StatusActive, + ConnectionIds: []string{connection.PublicId}, + }, + { + SessionId: sess2.PublicId, + Status: session.StatusActive, + ConnectionIds: []string{connection2.PublicId}, + }, + }, + want: []session.StateReport{ + { + SessionId: sess.PublicId, + Status: session.StatusCanceling, + }, + { + SessionId: sess2.PublicId, + Status: session.StatusCanceling, + }, + }, + orphanedConnections: []string{connection3.PublicId}, + } + }, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + tc := tt.caseFn(t) + + got, err := session.WorkerStatusReport(ctx, repo, connRepo, tc.worker.PrivateId, tc.req) + require.NoError(err) + assert.ElementsMatch(tc.want, got) + for _, dc := range tc.orphanedConnections { + gotConn, states, err := connRepo.LookupConnection(ctx, dc) + require.NoError(err) + assert.Equal(session.ConnectionSystemError, session.ClosedReason(gotConn.ClosedReason)) + assert.Equal(2, len(states)) + assert.Nil(states[0].EndTime) + assert.Equal(session.StatusClosed, states[0].Status) + } + }) + } +} diff --git a/internal/tests/cluster/session_cleanup_test.go b/internal/tests/cluster/session_cleanup_test.go index 1b3749c339..eeb3e5ca58 100644 --- a/internal/tests/cluster/session_cleanup_test.go +++ b/internal/tests/cluster/session_cleanup_test.go @@ -35,21 +35,14 @@ import ( // worker is managing the lifecycle of a connection and will properly // unclaim it closed once the connection resumes, ensuring the // connection is marked as closed on the worker. -// -// * controller: Here, the controller is the one doing the work. The -// connection will be open on the worker until status checks resume -// from the worker. 
At this point, the controller will request the -// status change on the worker, physically closing the connection -// there. type timeoutBurdenType string const ( - timeoutBurdenTypeDefault timeoutBurdenType = "default" - timeoutBurdenTypeWorker timeoutBurdenType = "worker" - timeoutBurdenTypeController timeoutBurdenType = "controller" + timeoutBurdenTypeDefault timeoutBurdenType = "default" + timeoutBurdenTypeWorker timeoutBurdenType = "worker" ) -var timeoutBurdenCases = []timeoutBurdenType{timeoutBurdenTypeDefault, timeoutBurdenTypeWorker, timeoutBurdenTypeController} +var timeoutBurdenCases = []timeoutBurdenType{timeoutBurdenTypeDefault, timeoutBurdenTypeWorker} func controllerGracePeriod(ty timeoutBurdenType) time.Duration { if ty == timeoutBurdenTypeWorker { @@ -60,10 +53,6 @@ func controllerGracePeriod(ty timeoutBurdenType) time.Duration { } func workerGracePeriod(ty timeoutBurdenType) time.Duration { - if ty == timeoutBurdenTypeController { - return helper.DefaultGracePeriod * 10 - } - return helper.DefaultGracePeriod } @@ -170,30 +159,16 @@ func testWorkerSessionCleanupSingle(burdenCase timeoutBurdenType) func(t *testin case timeoutBurdenTypeWorker: // Wait on worker, then check controller sess.ExpectConnectionStateOnWorker(ctx, t, w1, session.StatusClosed) - sess.ExpectConnectionStateOnController(ctx, t, c1.Controller().SessionRepoFn, session.StatusConnected) - - case timeoutBurdenTypeController: - // Wait on controller, then check worker - sess.ExpectConnectionStateOnController(ctx, t, c1.Controller().SessionRepoFn, session.StatusClosed) - sess.ExpectConnectionStateOnWorker(ctx, t, w1, session.StatusConnected) + sess.ExpectConnectionStateOnController(ctx, t, c1.Controller().ConnectionRepoFn, session.StatusConnected) default: // Should be closed on both worker and controller. Wait on // worker then check controller. sess.ExpectConnectionStateOnWorker(ctx, t, w1, session.StatusClosed) - sess.ExpectConnectionStateOnController(ctx, t, c1.Controller().SessionRepoFn, session.StatusClosed) + sess.ExpectConnectionStateOnController(ctx, t, c1.Controller().ConnectionRepoFn, session.StatusClosed) } - // Run send/receive test again to check expected connection-level - // behavior - if burdenCase == timeoutBurdenTypeController { - // Burden on controller, should be successful until connection - // resumes - sConn.TestSendRecvAll(t) - } else { - // Connection should die in other cases - sConn.TestSendRecvFail(t) - } + sConn.TestSendRecvFail(t) // Resume the connection, and reconnect. event.WriteSysEvent(ctx, op, "resuming controller/worker link") @@ -211,13 +186,7 @@ func testWorkerSessionCleanupSingle(burdenCase timeoutBurdenType) func(t *testin // a connection status, ensure that our old session's // connections are actually closed now that the worker is // properly reporting in again. - sess.ExpectConnectionStateOnController(ctx, t, c1.Controller().SessionRepoFn, session.StatusClosed) - - case timeoutBurdenTypeController: - // If we are expecting the controller to be the source of - // truth, the connection should now be forcibly closed after - // the worker gets a status change request back. 
- sConn.TestSendRecvFail(t) + sess.ExpectConnectionStateOnController(ctx, t, c1.Controller().ConnectionRepoFn, session.StatusClosed) } // Proceed with new connection test @@ -371,30 +340,18 @@ func testWorkerSessionCleanupMulti(burdenCase timeoutBurdenType) func(t *testing case timeoutBurdenTypeWorker: // Wait on worker, then check controller sess.ExpectConnectionStateOnWorker(ctx, t, w1, session.StatusClosed) - sess.ExpectConnectionStateOnController(ctx, t, c1.Controller().SessionRepoFn, session.StatusConnected) - - case timeoutBurdenTypeController: - // Wait on controller, then check worker - sess.ExpectConnectionStateOnController(ctx, t, c1.Controller().SessionRepoFn, session.StatusClosed) - sess.ExpectConnectionStateOnWorker(ctx, t, w1, session.StatusConnected) + sess.ExpectConnectionStateOnController(ctx, t, c1.Controller().ConnectionRepoFn, session.StatusConnected) default: // Should be closed on both worker and controller. Wait on // worker then check controller. sess.ExpectConnectionStateOnWorker(ctx, t, w1, session.StatusClosed) - sess.ExpectConnectionStateOnController(ctx, t, c1.Controller().SessionRepoFn, session.StatusClosed) + sess.ExpectConnectionStateOnController(ctx, t, c1.Controller().ConnectionRepoFn, session.StatusClosed) } // Run send/receive test again to check expected connection-level // behavior - if burdenCase == timeoutBurdenTypeController { - // Burden on controller, should be successful until connection - // resumes - sConn.TestSendRecvAll(t) - } else { - // Connection should die in other cases - sConn.TestSendRecvFail(t) - } + sConn.TestSendRecvFail(t) // Finally resume both, try again. Should behave as per normal. event.WriteSysEvent(ctx, op, "resuming connections to both controllers") @@ -413,13 +370,7 @@ func testWorkerSessionCleanupMulti(burdenCase timeoutBurdenType) func(t *testing // a connection status, ensure that our old session's // connections are actually closed now that the worker is // properly reporting in again. - sess.ExpectConnectionStateOnController(ctx, t, c1.Controller().SessionRepoFn, session.StatusClosed) - - case timeoutBurdenTypeController: - // If we are expecting the controller to be the source of - // truth, the connection should now be forcibly closed after - // the worker gets a status change request back. - sConn.TestSendRecvFail(t) + sess.ExpectConnectionStateOnController(ctx, t, c1.Controller().ConnectionRepoFn, session.StatusClosed) } // Proceed with new connection test diff --git a/internal/tests/helper/testing_helper.go b/internal/tests/helper/testing_helper.go index 5697abf1dd..9ee2cad148 100644 --- a/internal/tests/helper/testing_helper.go +++ b/internal/tests/helper/testing_helper.go @@ -168,7 +168,7 @@ func (s *TestSession) connect(ctx context.Context, t *testing.T) net.Conn { func (s *TestSession) ExpectConnectionStateOnController( ctx context.Context, t *testing.T, - sessionRepoFn common.SessionRepoFactory, + connectionRepoFn common.ConnectionRepoFactory, expectState session.ConnectionStatus, ) { t.Helper() @@ -181,10 +181,10 @@ func (s *TestSession) ExpectConnectionStateOnController( // This is just for initialization of the actual state set. 
const sessionStatusUnknown session.ConnectionStatus = "unknown" - sessionRepo, err := sessionRepoFn() + connectionRepo, err := connectionRepoFn() require.NoError(err) - conns, err := sessionRepo.ListConnectionsBySessionId(ctx, s.sessionId) + conns, err := connectionRepo.ListConnectionsBySessionId(ctx, s.sessionId) require.NoError(err) // To avoid misleading passing tests, we require this test be used // with sessions with connections.. @@ -208,7 +208,7 @@ func (s *TestSession) ExpectConnectionStateOnController( } for i, conn := range conns { - _, states, err := sessionRepo.LookupConnection(ctx, conn.PublicId, nil) + _, states, err := connectionRepo.LookupConnection(ctx, conn.PublicId, nil) require.NoError(err) // Look at the first state in the returned list, which will // be the most recent state. diff --git a/website/content/docs/concepts/domain-model/session-connections.mdx b/website/content/docs/concepts/domain-model/session-connections.mdx new file mode 100644 index 0000000000..596e068b3d --- /dev/null +++ b/website/content/docs/concepts/domain-model/session-connections.mdx @@ -0,0 +1,26 @@ +--- +layout: docs +page_title: Domain Model - Session Connections +description: |- + The anatomy of a Boundary session connection +--- + +# Session Connections +A session connection represents an authorized proxy between a [user][] and a [host][]. After the creation of a +[session][], a user initiates a connection to a [target][] using the Boundary-provided +proxy information and [credentials][] (if applicable). + +Users can create multiple connections to a [target][], so long as the [session][] has not expired or reached its maximum +number of connections. + +Session connections terminate when the user exits the proxy or when the [session][] terminates. + +## Referenced By + +- [Session][] + +[credentials]: /docs/concepts/domain-model/credentials [host]: /docs/concepts/domain-model/hosts [session]: /docs/concepts/domain-model/sessions [target]: /docs/concepts/domain-model/targets [user]: /docs/concepts/domain-model/users diff --git a/website/content/docs/concepts/domain-model/sessions.mdx b/website/content/docs/concepts/domain-model/sessions.mdx index f66932a367..3d31974e51 100644 --- a/website/content/docs/concepts/domain-model/sessions.mdx +++ b/website/content/docs/concepts/domain-model/sessions.mdx @@ -8,7 +8,7 @@ description: |- # Sessions A session is -a set of related connections +a set of related [connections][] between a [user][] and a [host][]. A session may include a set of [credentials][] which define the permissions granted to the [user][] on the [host][] for the duration @@ -74,6 +74,7 @@ Changes to a user's permissions do not effect existing sessions.
[accounts]: /docs/concepts/domain-model/accounts [authentication method]: /docs/concepts/domain-model/auth-methods [authentication methods]: /docs/concepts/domain-model/auth-methods +[connections]: /docs/concepts/domain-model/session-connections [credential library]: /docs/concepts/domain-model/credential-libraries [credential libraries]: /docs/concepts/domain-model/credential-libraries [credential store]: /docs/concepts/domain-model/credential-stores diff --git a/website/data/docs-nav-data.json b/website/data/docs-nav-data.json index 349ee5e388..88803af4d1 100644 --- a/website/data/docs-nav-data.json +++ b/website/data/docs-nav-data.json @@ -193,6 +193,10 @@ "title": "Sessions", "path": "concepts/domain-model/sessions" }, + { + "title": "Session Connections", + "path": "concepts/domain-model/session-connections" + }, { "title": "Targets", "path": "concepts/domain-model/targets"
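Taken together, the new domain service functions replace the old trigger-driven cleanup with an explicit controller-side sequence: reconcile the worker's status report (closing any orphaned connections), close the connections the worker has finished with, and let CloseConnections attempt session termination afterward. The following sketch shows how a caller might wire the pieces together; processWorkerUpdate and its error handling are illustrative, not part of this change:

```go
package session

import "context"

// processWorkerUpdate is a hypothetical sketch of the flow these services
// enable. WorkerStatusReport reconciles the worker's view with the
// controller's (marking orphaned connections closed), and CloseConnections
// closes finished connections and then attempts to terminate their sessions,
// so no database trigger is involved anymore.
func processWorkerUpdate(ctx context.Context, repo *Repository, connRepo *ConnectionRepository,
	workerId string, report []StateReport, finished []CloseWith) ([]StateReport, error) {
	notActive, err := WorkerStatusReport(ctx, repo, connRepo, workerId, report)
	if err != nil {
		return nil, err
	}
	if len(finished) > 0 {
		if _, err := CloseConnections(ctx, repo, connRepo, finished); err != nil {
			return notActive, err
		}
	}
	// notActive lists the reported sessions the worker should tear down
	// locally because they are canceling or terminated on the controller.
	return notActive, nil
}
```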