diff --git a/lib/srv/sess.go b/lib/srv/sess.go index d61ed96ea4af..a0eea46511f8 100644 --- a/lib/srv/sess.go +++ b/lib/srv/sess.go @@ -413,7 +413,9 @@ func (s *SessionRegistry) OpenExecSession(ctx context.Context, channel ssh.Chann return trace.Wrap(err) } - canStart, _, err := sess.checkIfStart() + sess.mu.Lock() + canStart, _, err := sess.checkIfStartUnderLock() + sess.mu.Unlock() if err != nil { return trace.Wrap(err) } @@ -500,7 +502,7 @@ func (s *SessionRegistry) isApprovedFileTransfer(scx *ServerContext) (bool, erro sess.fileTransferReq = nil sess.BroadcastMessage("file transfer request %s denied due to %s attempting to transfer files", req.ID, scx.Identity.TeleportUser) - _ = s.NotifyFileTransferRequest(req, FileTransferDenied, scx) + _ = s.notifyFileTransferRequestUnderLock(req, FileTransferDenied, scx) return false, trace.AccessDenied("Teleport user does not match original requester") } @@ -533,9 +535,9 @@ const ( FileTransferDenied FileTransferRequestEvent = "file_transfer_request_deny" ) -// NotifyFileTransferRequest is called to notify all members of a party that a file transfer request has been created/approved/denied. +// notifyFileTransferRequestUnderLock is called to notify all members of a party that a file transfer request has been created/approved/denied. // The notification is a global ssh request and requires the client to update its UI state accordingly. -func (s *SessionRegistry) NotifyFileTransferRequest(req *FileTransferRequest, res FileTransferRequestEvent, scx *ServerContext) error { +func (s *SessionRegistry) notifyFileTransferRequestUnderLock(req *FileTransferRequest, res FileTransferRequestEvent, scx *ServerContext) error { session := scx.getSession() if session == nil { s.logger.DebugContext( @@ -1090,7 +1092,7 @@ func (s *session) emitSessionJoinEvent(ctx *ServerContext) { // Notify all members of the party that a new member has joined over the // "x-teleport-event" channel. - for _, p := range s.parties { + for _, p := range s.getParties() { if len(notifyPartyPayload) == 0 { s.logger.WarnContext(ctx.srv.Context(), "No session join event to send to party.", "party", p) continue @@ -1108,10 +1110,10 @@ func (s *session) emitSessionJoinEvent(ctx *ServerContext) { } } -// emitSessionLeaveEvent emits a session leave event to both the Audit Log as +// emitSessionLeaveEventUnderLock emits a session leave event to both the Audit Log as // well as sending a "x-teleport-event" global request on the SSH connection. // Must be called under session Lock. -func (s *session) emitSessionLeaveEvent(ctx *ServerContext) { +func (s *session) emitSessionLeaveEventUnderLock(ctx *ServerContext) { sessionLeaveEvent := &apievents.SessionLeave{ Metadata: apievents.Metadata{ Type: events.SessionLeaveEvent, @@ -1317,7 +1319,9 @@ func (s *session) launch() { // startInteractive starts a new interactive process (or a shell) in the // current session. func (s *session) startInteractive(ctx context.Context, scx *ServerContext, p *party) error { - canStart, _, err := s.checkIfStart() + s.mu.Lock() + canStart, _, err := s.checkIfStartUnderLock() + s.mu.Unlock() if err != nil { return trace.Wrap(err) } @@ -1588,11 +1592,8 @@ func (s *session) startExec(ctx context.Context, channel ssh.Channel, scx *Serve } func (s *session) broadcastResult(r ExecResult) { - s.mu.Lock() - defer s.mu.Unlock() - payload := ssh.Marshal(struct{ C uint32 }{C: uint32(r.Code)}) - for _, p := range s.parties { + for _, p := range s.getParties() { if _, err := p.ch.SendRequest("exit-status", false, payload); err != nil { s.logger.InfoContext( s.serverCtx, "Failed to send exit status", @@ -1604,7 +1605,7 @@ func (s *session) broadcastResult(r ExecResult) { } func (s *session) String() string { - return fmt.Sprintf("session(id=%v, parties=%v)", s.id, len(s.parties)) + return fmt.Sprintf("session(id=%v, parties=%v)", s.id, len(s.getParties())) } // removePartyUnderLock removes the party from the in-memory map that holds all party members @@ -1630,9 +1631,9 @@ func (s *session) removePartyUnderLock(p *party) error { // Emit session leave event to both the Audit Log and over the // "x-teleport-event" channel in the SSH connection. - s.emitSessionLeaveEvent(p.ctx) + s.emitSessionLeaveEventUnderLock(p.ctx) - canRun, policyOptions, err := s.checkIfStart() + canRun, policyOptions, err := s.checkIfStartUnderLock() if err != nil { return trace.Wrap(err) } @@ -1866,7 +1867,7 @@ func (s *session) addFileTransferRequest(params *rsession.FileTransferRequestPar } else { s.BroadcastMessage("User %s would like to upload %s to: %s", params.Requester, params.Filename, params.Location) } - err = s.registry.NotifyFileTransferRequest(s.fileTransferReq, FileTransferUpdate, scx) + err = s.registry.notifyFileTransferRequestUnderLock(s.fileTransferReq, FileTransferUpdate, scx) return trace.Wrap(err) } @@ -1909,7 +1910,7 @@ func (s *session) approveFileTransferRequest(params *rsession.FileTransferDecisi } else { eventType = FileTransferUpdate } - err = s.registry.NotifyFileTransferRequest(s.fileTransferReq, eventType, scx) + err = s.registry.notifyFileTransferRequestUnderLock(s.fileTransferReq, eventType, scx) return trace.Wrap(err) } @@ -1942,12 +1943,15 @@ func (s *session) denyFileTransferRequest(params *rsession.FileTransferDecisionP s.fileTransferReq = nil s.BroadcastMessage("%s denied file transfer request %s", scx.Identity.TeleportUser, req.ID) - err := s.registry.NotifyFileTransferRequest(req, FileTransferDenied, scx) + err := s.registry.notifyFileTransferRequestUnderLock(req, FileTransferDenied, scx) return trace.Wrap(err) } -func (s *session) checkIfStart() (bool, auth.PolicyOptions, error) { +// checkIfStartUnderLock determines if any moderation policies associated with +// the session are satisfied. +// Must be called under session Lock. +func (s *session) checkIfStartUnderLock() (bool, auth.PolicyOptions, error) { var participants []auth.SessionAccessContext for _, party := range s.parties { @@ -1986,7 +1990,7 @@ func (s *session) addParty(p *party, mode types.SessionParticipantMode) error { } if len(s.parties) == 0 { - canStart, _, err := s.checkIfStart() + canStart, _, err := s.checkIfStartUnderLock() if err != nil { return trace.Wrap(err) } @@ -2050,7 +2054,7 @@ func (s *session) addParty(p *party, mode types.SessionParticipantMode) error { } if s.tracker.GetState() == types.SessionState_SessionStatePending { - canStart, _, err := s.checkIfStart() + canStart, _, err := s.checkIfStartUnderLock() if err != nil { return trace.Wrap(err) }