Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent racy access to session parties #47816

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 25 additions & 21 deletions lib/srv/sess.go
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,9 @@ func (s *SessionRegistry) OpenExecSession(ctx context.Context, channel ssh.Chann
return trace.Wrap(err)
}

canStart, _, err := sess.checkIfStart()
sess.mu.Lock()
canStart, _, err := sess.checkIfStartUnderLock()
sess.mu.Unlock()
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -500,7 +502,7 @@ func (s *SessionRegistry) isApprovedFileTransfer(scx *ServerContext) (bool, erro
sess.fileTransferReq = nil

sess.BroadcastMessage("file transfer request %s denied due to %s attempting to transfer files", req.ID, scx.Identity.TeleportUser)
_ = s.NotifyFileTransferRequest(req, FileTransferDenied, scx)
_ = s.notifyFileTransferRequestUnderLock(req, FileTransferDenied, scx)

return false, trace.AccessDenied("Teleport user does not match original requester")
}
Expand Down Expand Up @@ -533,9 +535,9 @@ const (
FileTransferDenied FileTransferRequestEvent = "file_transfer_request_deny"
)

// NotifyFileTransferRequest is called to notify all members of a party that a file transfer request has been created/approved/denied.
// notifyFileTransferRequestUnderLock is called to notify all members of a party that a file transfer request has been created/approved/denied.
// The notification is a global ssh request and requires the client to update its UI state accordingly.
func (s *SessionRegistry) NotifyFileTransferRequest(req *FileTransferRequest, res FileTransferRequestEvent, scx *ServerContext) error {
func (s *SessionRegistry) notifyFileTransferRequestUnderLock(req *FileTransferRequest, res FileTransferRequestEvent, scx *ServerContext) error {
session := scx.getSession()
if session == nil {
s.logger.DebugContext(
Expand Down Expand Up @@ -1090,7 +1092,7 @@ func (s *session) emitSessionJoinEvent(ctx *ServerContext) {

// Notify all members of the party that a new member has joined over the
// "x-teleport-event" channel.
for _, p := range s.parties {
for _, p := range s.getParties() {
if len(notifyPartyPayload) == 0 {
s.logger.WarnContext(ctx.srv.Context(), "No session join event to send to party.", "party", p)
continue
Expand All @@ -1108,10 +1110,10 @@ func (s *session) emitSessionJoinEvent(ctx *ServerContext) {
}
}

// emitSessionLeaveEvent emits a session leave event to both the Audit Log as
// emitSessionLeaveEventUnderLock emits a session leave event to both the Audit Log as
// well as sending a "x-teleport-event" global request on the SSH connection.
// Must be called under session Lock.
func (s *session) emitSessionLeaveEvent(ctx *ServerContext) {
func (s *session) emitSessionLeaveEventUnderLock(ctx *ServerContext) {
sessionLeaveEvent := &apievents.SessionLeave{
Metadata: apievents.Metadata{
Type: events.SessionLeaveEvent,
Expand Down Expand Up @@ -1317,7 +1319,9 @@ func (s *session) launch() {
// startInteractive starts a new interactive process (or a shell) in the
// current session.
func (s *session) startInteractive(ctx context.Context, scx *ServerContext, p *party) error {
canStart, _, err := s.checkIfStart()
s.mu.Lock()
canStart, _, err := s.checkIfStartUnderLock()
s.mu.Unlock()
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -1588,11 +1592,8 @@ func (s *session) startExec(ctx context.Context, channel ssh.Channel, scx *Serve
}

func (s *session) broadcastResult(r ExecResult) {
s.mu.Lock()
defer s.mu.Unlock()

payload := ssh.Marshal(struct{ C uint32 }{C: uint32(r.Code)})
for _, p := range s.parties {
for _, p := range s.getParties() {
if _, err := p.ch.SendRequest("exit-status", false, payload); err != nil {
s.logger.InfoContext(
s.serverCtx, "Failed to send exit status",
Expand All @@ -1604,7 +1605,7 @@ func (s *session) broadcastResult(r ExecResult) {
}

func (s *session) String() string {
return fmt.Sprintf("session(id=%v, parties=%v)", s.id, len(s.parties))
return fmt.Sprintf("session(id=%v, parties=%v)", s.id, len(s.getParties()))
}

// removePartyUnderLock removes the party from the in-memory map that holds all party members
Expand All @@ -1630,9 +1631,9 @@ func (s *session) removePartyUnderLock(p *party) error {

// Emit session leave event to both the Audit Log and over the
// "x-teleport-event" channel in the SSH connection.
s.emitSessionLeaveEvent(p.ctx)
s.emitSessionLeaveEventUnderLock(p.ctx)

canRun, policyOptions, err := s.checkIfStart()
canRun, policyOptions, err := s.checkIfStartUnderLock()
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -1866,7 +1867,7 @@ func (s *session) addFileTransferRequest(params *rsession.FileTransferRequestPar
} else {
s.BroadcastMessage("User %s would like to upload %s to: %s", params.Requester, params.Filename, params.Location)
}
err = s.registry.NotifyFileTransferRequest(s.fileTransferReq, FileTransferUpdate, scx)
err = s.registry.notifyFileTransferRequestUnderLock(s.fileTransferReq, FileTransferUpdate, scx)

return trace.Wrap(err)
}
Expand Down Expand Up @@ -1909,7 +1910,7 @@ func (s *session) approveFileTransferRequest(params *rsession.FileTransferDecisi
} else {
eventType = FileTransferUpdate
}
err = s.registry.NotifyFileTransferRequest(s.fileTransferReq, eventType, scx)
err = s.registry.notifyFileTransferRequestUnderLock(s.fileTransferReq, eventType, scx)

return trace.Wrap(err)
}
Expand Down Expand Up @@ -1942,12 +1943,15 @@ func (s *session) denyFileTransferRequest(params *rsession.FileTransferDecisionP
s.fileTransferReq = nil

s.BroadcastMessage("%s denied file transfer request %s", scx.Identity.TeleportUser, req.ID)
err := s.registry.NotifyFileTransferRequest(req, FileTransferDenied, scx)
err := s.registry.notifyFileTransferRequestUnderLock(req, FileTransferDenied, scx)

return trace.Wrap(err)
}

func (s *session) checkIfStart() (bool, auth.PolicyOptions, error) {
// checkIfStartUnderLock determines if any moderation policies associated with
// the session are satisfied.
// Must be called under session Lock.
func (s *session) checkIfStartUnderLock() (bool, auth.PolicyOptions, error) {
var participants []auth.SessionAccessContext

for _, party := range s.parties {
Expand Down Expand Up @@ -1986,7 +1990,7 @@ func (s *session) addParty(p *party, mode types.SessionParticipantMode) error {
}

if len(s.parties) == 0 {
canStart, _, err := s.checkIfStart()
canStart, _, err := s.checkIfStartUnderLock()
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -2050,7 +2054,7 @@ func (s *session) addParty(p *party, mode types.SessionParticipantMode) error {
}

if s.tracker.GetState() == types.SessionState_SessionStatePending {
canStart, _, err := s.checkIfStart()
canStart, _, err := s.checkIfStartUnderLock()
if err != nil {
return trace.Wrap(err)
}
Expand Down
Loading