Skip to content

Commit

Permalink
Merge pull request #342 from matrix-org/dmr/preemptive-bans
Browse files Browse the repository at this point in the history
  • Loading branch information
David Robertson authored Oct 17, 2023
2 parents 0c3a2b9 + 47b7fac commit c3164d6
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 21 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
/syncv3
node_modules

# Go workspaces
go.work
go.work.sum
25 changes: 7 additions & 18 deletions sync3/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ func (d *Dispatcher) OnNewEvent(
targetUser := ""
membership := ""
shouldForceInitial := false
leaveAfterJoinOrInvite := false
if ed.EventType == "m.room.member" && ed.StateKey != nil {
targetUser = *ed.StateKey
membership = ed.Content.Get("membership").Str
Expand All @@ -173,7 +174,7 @@ func (d *Dispatcher) OnNewEvent(
case "ban":
fallthrough
case "leave":
d.jrt.UserLeftRoom(targetUser, ed.RoomID)
leaveAfterJoinOrInvite = d.jrt.UserLeftRoom(targetUser, ed.RoomID)
}
ed.InviteCount = d.jrt.NumInvitedUsersForRoom(ed.RoomID)
}
Expand All @@ -186,6 +187,11 @@ func (d *Dispatcher) OnNewEvent(
return d.ReceiverForUser(userID) != nil
})
ed.JoinCount = joinCount
if leaveAfterJoinOrInvite {
// Only tell the target user about a leave if they were previously aware of the
// room. This prevents us from leaking pre-emptive bans.
userIDs = append(userIDs, targetUser)
}
d.notifyListeners(ctx, ed, userIDs, targetUser, shouldForceInitial, membership)
}

Expand Down Expand Up @@ -256,35 +262,18 @@ func (d *Dispatcher) notifyListeners(ctx context.Context, ed *caches.EventData,
}

// per-user listeners
notifiedTarget := false
for _, userID := range userIDs {
l := d.userToReceiver[userID]
if l != nil {
edd := *ed
if targetUser == userID {
notifiedTarget = true
if shouldForceInitial {
edd.ForceInitial = true
}
}
l.OnNewEvent(ctx, &edd)
}
}
if targetUser != "" && !notifiedTarget { // e.g invites/leaves where you aren't joined yet but need to know about it
// We expect invites to come down the invitee's poller, which triggers OnInvite code paths and
// not normal event codepaths. We need the separate code path to ensure invite stripped state
// is sent to the conn and not live data. Hence, if we get the invite event early from a different
// connection, do not send it to the target, as they must wait for the invite on their poller.
if membership != "invite" {
if shouldForceInitial {
ed.ForceInitial = true
}
l := d.userToReceiver[targetUser]
if l != nil {
l.OnNewEvent(ctx, ed)
}
}
}
}

func (d *Dispatcher) OnInvalidateRoom(ctx context.Context, roomID string) {
Expand Down
15 changes: 12 additions & 3 deletions sync3/tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,19 +115,28 @@ func (t *JoinedRoomsTracker) UsersJoinedRoom(userIDs []string, roomID string) bo
}

// UserLeftRoom marks the given user as having left the given room.
func (t *JoinedRoomsTracker) UserLeftRoom(userID, roomID string) {
// Returns true if this user _was_ joined or invited to the room before this call,
// and false otherwise.
func (t *JoinedRoomsTracker) UserLeftRoom(userID, roomID string) bool {
t.mu.Lock()
defer t.mu.Unlock()
joinedRooms := t.userIDToJoinedRooms[userID]
delete(joinedRooms, roomID)
joinedUsers := t.roomIDToJoinedUsers[roomID]
delete(joinedUsers, userID)
invitedUsers := t.roomIDToInvitedUsers[roomID]

_, wasJoined := joinedUsers[userID]
_, wasInvited := invitedUsers[userID]

delete(joinedRooms, roomID)
delete(joinedUsers, userID)
delete(invitedUsers, userID)
t.userIDToJoinedRooms[userID] = joinedRooms
t.roomIDToJoinedUsers[roomID] = joinedUsers
t.roomIDToInvitedUsers[roomID] = invitedUsers

return wasJoined || wasInvited
}

func (t *JoinedRoomsTracker) JoinedRoomsForUser(userID string) []string {
t.mu.RLock()
defer t.mu.RUnlock()
Expand Down
58 changes: 58 additions & 0 deletions sync3/tracker_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sync3

import (
"fmt"
"sort"
"testing"
)
Expand Down Expand Up @@ -82,6 +83,63 @@ func TestTrackerStartup(t *testing.T) {
assertInt(t, jrt.NumInvitedUsersForRoom(roomC), 0)
}

func TestJoinedRoomsTracker_UserLeftRoom_ReturnValue(t *testing.T) {
alice := "@alice"
bob := "@bob"

// Tell the tracker that alice left various rooms. Assert its return value is sensible.

tcs := []struct {
id string
joined []string
invited []string
expectedResult bool
}{
{
id: "!a",
joined: []string{alice, bob},
invited: nil,
expectedResult: true,
},
{
id: "!b",
joined: []string{alice},
invited: nil,
expectedResult: true,
},
{
id: "!c",
joined: []string{bob},
invited: nil,
expectedResult: false,
},
{
id: "!d",
joined: nil,
invited: nil,
expectedResult: false,
},
{
id: "!e",
joined: nil,
invited: []string{alice},
expectedResult: true,
},
}

jrt := NewJoinedRoomsTracker()
for _, tc := range tcs {
jrt.UsersJoinedRoom(tc.joined, tc.id)
jrt.UsersInvitedToRoom(tc.invited, tc.id)
}

// Tell the tracker that Alice left every room. Check the return value is sensible.
for _, tc := range tcs {
wasJoinedOrInvited := jrt.UserLeftRoom(alice, tc.id)
assertBool(t, fmt.Sprintf("wasJoinedOrInvited[%s]", tc.id), wasJoinedOrInvited, tc.expectedResult)
}
}

func assertBool(t *testing.T, msg string, got, want bool) {
t.Helper()
if got != want {
Expand Down
58 changes: 58 additions & 0 deletions tests-e2e/membership_transitions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -802,3 +802,61 @@ func TestMemberCounts(t *testing.T) {
},
}))
}

func TestPreemptiveBanIsNotLeaked(t *testing.T) {
alice := registerNamedUser(t, "alice")
nigel := registerNamedUser(t, "nigel")

t.Log("Alice creates a public room and a DM with Nigel.")
public := alice.MustCreateRoom(t, map[string]interface{}{"preset": "public_chat"})
dm := alice.MustCreateRoom(t, map[string]interface{}{"preset": "private_chat", "invite": []string{nigel.UserID}})

t.Log("Nigel joins the DM")
nigel.JoinRoom(t, dm, nil)

t.Log("Alice sends a sentinel message into the DM.")
dmSentinel := alice.SendEventSynced(t, dm, b.Event{
Type: "m.room.message",
Content: map[string]interface{}{"body": "sentinel, sentinel, where have you been?", "msgtype": "m.text"},
})

t.Log("Nigel does an initial sliding sync.")
nigelRes := nigel.SlidingSync(t, sync3.Request{
Lists: map[string]sync3.RequestList{
"a": {
RoomSubscription: sync3.RoomSubscription{
TimelineLimit: 20,
},
Ranges: sync3.SliceRanges{{0, 10}},
},
},
})
t.Log("Nigel sees the sentinel.")
m.MatchResponse(t, nigelRes, m.MatchRoomSubscription(dm, MatchRoomTimelineMostRecent(1, []Event{{ID: dmSentinel}})))

t.Log("Alice pre-emptively bans Nigel from the public room.")
alice.MustDo(t, "POST", []string{"_matrix", "client", "v3", "rooms", public, "ban"},
client.WithJSONBody(t, map[string]any{"user_id": nigel.UserID}))

t.Log("Alice sliding syncs until she sees the ban.")
alice.SlidingSyncUntilMembership(t, "", public, nigel, "ban")

t.Log("Alice sends a second sentinel in Nigel's DM.")
dmSentinel2 := alice.SendEventSynced(t, dm, b.Event{
Type: "m.room.message",
Content: map[string]interface{}{"body": "sentinel 2 placeholder boogaloo", "msgtype": "m.text"},
})

t.Log("Nigel syncs until he sees the second sentinel. He should NOT see his ban event.")

nigelRes = nigel.SlidingSyncUntil(t, nigelRes.Pos, sync3.Request{}, func(response *sync3.Response) error {
seenPublicRoom := m.MatchRoomSubscription(public)
if seenPublicRoom(response) == nil {
t.Errorf("Nigel had a room subscription for the public room, but shouldn't have.")
m.LogResponse(t)(response)
t.FailNow()
}
seenSentinel := m.MatchRoomSubscription(dm, MatchRoomTimelineMostRecent(1, []Event{{ID: dmSentinel2}}))
return seenSentinel(response)
})
}

0 comments on commit c3164d6

Please sign in to comment.