Skip to content

Commit

Permalink
Prevent racy access to session parties
Browse files Browse the repository at this point in the history
Prefer using session.getParties instead of using session.parties
directly to prevent races when new parties are added. This also
prevents holding the mutex locking the parties when performing
external io operations.
  • Loading branch information
rosstimothy committed Oct 22, 2024
1 parent 0437596 commit 0ad644d
Showing 1 changed file with 25 additions and 21 deletions.
46 changes: 25 additions & 21 deletions lib/srv/sess.go
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,9 @@ func (s *SessionRegistry) OpenExecSession(ctx context.Context, channel ssh.Chann
return trace.Wrap(err)
}

canStart, _, err := sess.checkIfStart()
sess.mu.Lock()
canStart, _, err := sess.checkIfStartUnderLock()
sess.mu.Unlock()
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -500,7 +502,7 @@ func (s *SessionRegistry) isApprovedFileTransfer(scx *ServerContext) (bool, erro
sess.fileTransferReq = nil

sess.BroadcastMessage("file transfer request %s denied due to %s attempting to transfer files", req.ID, scx.Identity.TeleportUser)
_ = s.NotifyFileTransferRequest(req, FileTransferDenied, scx)
_ = s.notifyFileTransferRequestUnderLock(req, FileTransferDenied, scx)

return false, trace.AccessDenied("Teleport user does not match original requester")
}
Expand Down Expand Up @@ -533,9 +535,9 @@ const (
FileTransferDenied FileTransferRequestEvent = "file_transfer_request_deny"
)

// NotifyFileTransferRequest is called to notify all members of a party that a file transfer request has been created/approved/denied.
// notifyFileTransferRequestUnderLock is called to notify all members of a party that a file transfer request has been created/approved/denied.
// The notification is a global ssh request and requires the client to update its UI state accordingly.
func (s *SessionRegistry) NotifyFileTransferRequest(req *FileTransferRequest, res FileTransferRequestEvent, scx *ServerContext) error {
func (s *SessionRegistry) notifyFileTransferRequestUnderLock(req *FileTransferRequest, res FileTransferRequestEvent, scx *ServerContext) error {
session := scx.getSession()
if session == nil {
s.logger.DebugContext(
Expand Down Expand Up @@ -1090,7 +1092,7 @@ func (s *session) emitSessionJoinEvent(ctx *ServerContext) {

// Notify all members of the party that a new member has joined over the
// "x-teleport-event" channel.
for _, p := range s.parties {
for _, p := range s.getParties() {
if len(notifyPartyPayload) == 0 {
s.logger.WarnContext(ctx.srv.Context(), "No session join event to send to party.", "party", p)
continue
Expand All @@ -1108,10 +1110,10 @@ func (s *session) emitSessionJoinEvent(ctx *ServerContext) {
}
}

// emitSessionLeaveEvent emits a session leave event to both the Audit Log as
// emitSessionLeaveEventUnderLock emits a session leave event to both the Audit Log as
// well as sending a "x-teleport-event" global request on the SSH connection.
// Must be called under session Lock.
func (s *session) emitSessionLeaveEvent(ctx *ServerContext) {
func (s *session) emitSessionLeaveEventUnderLock(ctx *ServerContext) {
sessionLeaveEvent := &apievents.SessionLeave{
Metadata: apievents.Metadata{
Type: events.SessionLeaveEvent,
Expand Down Expand Up @@ -1317,7 +1319,9 @@ func (s *session) launch() {
// startInteractive starts a new interactive process (or a shell) in the
// current session.
func (s *session) startInteractive(ctx context.Context, scx *ServerContext, p *party) error {
canStart, _, err := s.checkIfStart()
s.mu.Lock()
canStart, _, err := s.checkIfStartUnderLock()
s.mu.Unlock()
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -1588,11 +1592,8 @@ func (s *session) startExec(ctx context.Context, channel ssh.Channel, scx *Serve
}

func (s *session) broadcastResult(r ExecResult) {
s.mu.Lock()
defer s.mu.Unlock()

payload := ssh.Marshal(struct{ C uint32 }{C: uint32(r.Code)})
for _, p := range s.parties {
for _, p := range s.getParties() {
if _, err := p.ch.SendRequest("exit-status", false, payload); err != nil {
s.logger.InfoContext(
s.serverCtx, "Failed to send exit status",
Expand All @@ -1604,7 +1605,7 @@ func (s *session) broadcastResult(r ExecResult) {
}

func (s *session) String() string {
return fmt.Sprintf("session(id=%v, parties=%v)", s.id, len(s.parties))
return fmt.Sprintf("session(id=%v, parties=%v)", s.id, len(s.getParties()))
}

// removePartyUnderLock removes the party from the in-memory map that holds all party members
Expand All @@ -1630,9 +1631,9 @@ func (s *session) removePartyUnderLock(p *party) error {

// Emit session leave event to both the Audit Log and over the
// "x-teleport-event" channel in the SSH connection.
s.emitSessionLeaveEvent(p.ctx)
s.emitSessionLeaveEventUnderLock(p.ctx)

canRun, policyOptions, err := s.checkIfStart()
canRun, policyOptions, err := s.checkIfStartUnderLock()
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -1866,7 +1867,7 @@ func (s *session) addFileTransferRequest(params *rsession.FileTransferRequestPar
} else {
s.BroadcastMessage("User %s would like to upload %s to: %s", params.Requester, params.Filename, params.Location)
}
err = s.registry.NotifyFileTransferRequest(s.fileTransferReq, FileTransferUpdate, scx)
err = s.registry.notifyFileTransferRequestUnderLock(s.fileTransferReq, FileTransferUpdate, scx)

return trace.Wrap(err)
}
Expand Down Expand Up @@ -1909,7 +1910,7 @@ func (s *session) approveFileTransferRequest(params *rsession.FileTransferDecisi
} else {
eventType = FileTransferUpdate
}
err = s.registry.NotifyFileTransferRequest(s.fileTransferReq, eventType, scx)
err = s.registry.notifyFileTransferRequestUnderLock(s.fileTransferReq, eventType, scx)

return trace.Wrap(err)
}
Expand Down Expand Up @@ -1942,12 +1943,15 @@ func (s *session) denyFileTransferRequest(params *rsession.FileTransferDecisionP
s.fileTransferReq = nil

s.BroadcastMessage("%s denied file transfer request %s", scx.Identity.TeleportUser, req.ID)
err := s.registry.NotifyFileTransferRequest(req, FileTransferDenied, scx)
err := s.registry.notifyFileTransferRequestUnderLock(req, FileTransferDenied, scx)

return trace.Wrap(err)
}

func (s *session) checkIfStart() (bool, auth.PolicyOptions, error) {
// checkIfStartUnderLock determines if any moderation policies associated with
// the session are satisfied.
// Must be called under session Lock.
func (s *session) checkIfStartUnderLock() (bool, auth.PolicyOptions, error) {
var participants []auth.SessionAccessContext

for _, party := range s.parties {
Expand Down Expand Up @@ -1986,7 +1990,7 @@ func (s *session) addParty(p *party, mode types.SessionParticipantMode) error {
}

if len(s.parties) == 0 {
canStart, _, err := s.checkIfStart()
canStart, _, err := s.checkIfStartUnderLock()
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -2050,7 +2054,7 @@ func (s *session) addParty(p *party, mode types.SessionParticipantMode) error {
}

if s.tracker.GetState() == types.SessionState_SessionStatePending {
canStart, _, err := s.checkIfStart()
canStart, _, err := s.checkIfStartUnderLock()
if err != nil {
return trace.Wrap(err)
}
Expand Down

0 comments on commit 0ad644d

Please sign in to comment.