Skip to content

Commit

Permalink
Prevent racy access to session parties
Browse files Browse the repository at this point in the history
Prefer using session.getParties instead of using session.parties
directly to prevent races when new parties are added. This also
prevents holding the mutex locking the parties when performing
external io operations.
  • Loading branch information
rosstimothy committed Oct 22, 2024
1 parent 0437596 commit d5cede4
Showing 1 changed file with 6 additions and 9 deletions.
15 changes: 6 additions & 9 deletions lib/srv/sess.go
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ func (s *SessionRegistry) NotifyFileTransferRequest(req *FileTransferRequest, re
return trace.Wrap(err)
}

for _, p := range session.parties {
for _, p := range session.getParties() {
// Send the message as a global request.
_, _, err = p.sconn.SendRequest(teleport.SessionEvent, false, eventPayload)
if err != nil {
Expand Down Expand Up @@ -1090,7 +1090,7 @@ func (s *session) emitSessionJoinEvent(ctx *ServerContext) {

// Notify all members of the party that a new member has joined over the
// "x-teleport-event" channel.
for _, p := range s.parties {
for _, p := range s.getParties() {
if len(notifyPartyPayload) == 0 {
s.logger.WarnContext(ctx.srv.Context(), "No session join event to send to party.", "party", p)
continue
Expand Down Expand Up @@ -1136,7 +1136,7 @@ func (s *session) emitSessionLeaveEvent(ctx *ServerContext) {

// Notify all members of the party that a new member has left over the
// "x-teleport-event" channel.
for _, p := range s.parties {
for _, p := range s.getParties() {
eventPayload, err := utils.FastMarshal(sessionLeaveEvent)
if err != nil {
s.logger.WarnContext(ctx.srv.Context(), "Unable to marshal session leave event for party.", "error", err, "party", p)
Expand Down Expand Up @@ -1588,11 +1588,8 @@ func (s *session) startExec(ctx context.Context, channel ssh.Channel, scx *Serve
}

func (s *session) broadcastResult(r ExecResult) {
s.mu.Lock()
defer s.mu.Unlock()

payload := ssh.Marshal(struct{ C uint32 }{C: uint32(r.Code)})
for _, p := range s.parties {
for _, p := range s.getParties() {
if _, err := p.ch.SendRequest("exit-status", false, payload); err != nil {
s.logger.InfoContext(
s.serverCtx, "Failed to send exit status",
Expand All @@ -1604,7 +1601,7 @@ func (s *session) broadcastResult(r ExecResult) {
}

func (s *session) String() string {
return fmt.Sprintf("session(id=%v, parties=%v)", s.id, len(s.parties))
return fmt.Sprintf("session(id=%v, parties=%v)", s.id, len(s.getParties()))
}

// removePartyUnderLock removes the party from the in-memory map that holds all party members
Expand Down Expand Up @@ -1950,7 +1947,7 @@ func (s *session) denyFileTransferRequest(params *rsession.FileTransferDecisionP
func (s *session) checkIfStart() (bool, auth.PolicyOptions, error) {
var participants []auth.SessionAccessContext

for _, party := range s.parties {
for _, party := range s.getParties() {
if party.ctx.Identity.TeleportUser == s.initiator {
continue
}
Expand Down

0 comments on commit d5cede4

Please sign in to comment.