Skip to content

Commit

Permalink
Merge pull request #382 from matrix-org/kegan/conn-map-tests
Browse files Browse the repository at this point in the history
bugfix: when connections expire, only delete the affected connection
  • Loading branch information
kegsay authored Nov 24, 2023
2 parents a8fb526 + 81df9a5 commit c31eb3e
Show file tree
Hide file tree
Showing 3 changed files with 253 additions and 16 deletions.
30 changes: 21 additions & 9 deletions sync3/connmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ import (
"sync"
"time"

"golang.org/x/exp/slices"

"github.com/ReneKroon/ttlcache/v2"
"github.com/matrix-org/sliding-sync/internal"
"github.com/prometheus/client_golang/prometheus"
)

Expand All @@ -25,14 +28,14 @@ type ConnMap struct {
mu *sync.Mutex
}

func NewConnMap(enablePrometheus bool) *ConnMap {
func NewConnMap(enablePrometheus bool, ttl time.Duration) *ConnMap {
cm := &ConnMap{
userIDToConn: make(map[string][]*Conn),
connIDToConn: make(map[string]*Conn),
cache: ttlcache.NewCache(),
mu: &sync.Mutex{},
}
cm.cache.SetTTL(30 * time.Minute) // TODO: customisable
cm.cache.SetTTL(ttl)
cm.cache.SetExpirationCallback(cm.closeConnExpires)

if enablePrometheus {
Expand Down Expand Up @@ -132,7 +135,7 @@ func (m *ConnMap) getConn(cid ConnID) *Conn {
}

// Atomically gets or creates a connection with this connection ID. Calls newConnHandler if a new connection is required.
func (m *ConnMap) CreateConn(cid ConnID, cancel context.CancelFunc, newConnHandler func() ConnHandler) (*Conn, bool) {
func (m *ConnMap) CreateConn(cid ConnID, cancel context.CancelFunc, newConnHandler func() ConnHandler) *Conn {
// atomically check if a conn exists already and nuke it if it exists
m.mu.Lock()
defer m.mu.Unlock()
Expand All @@ -156,15 +159,19 @@ func (m *ConnMap) CreateConn(cid ConnID, cancel context.CancelFunc, newConnHandl
m.connIDToConn[cid.String()] = conn
m.userIDToConn[cid.UserID] = append(m.userIDToConn[cid.UserID], conn)
m.updateMetrics(len(m.connIDToConn))
return conn, true
return conn
}

func (m *ConnMap) CloseConnsForDevice(userID, deviceID string) {
logger.Trace().Str("user", userID).Str("device", deviceID).Msg("closing connections due to CloseConn()")
// gather open connections for this user|device
connIDs := m.connIDsForDevice(userID, deviceID)
for _, cid := range connIDs {
m.cache.Remove(cid.String()) // this will fire TTL callbacks which calls closeConn
err := m.cache.Remove(cid.String()) // this will fire TTL callbacks which calls closeConn
if err != nil {
logger.Err(err).Str("cid", cid.String()).Msg("CloseConnsForDevice: cid did not exist in ttlcache")
internal.GetSentryHubFromContextOrDefault(context.Background()).CaptureException(err)
}
}
}

Expand All @@ -191,7 +198,11 @@ func (m *ConnMap) CloseConnsForUsers(userIDs []string) (closed int) {
logger.Trace().Str("user", userID).Int("num_conns", len(conns)).Msg("closing all device connections due to CloseConn()")

for _, conn := range conns {
m.cache.Remove(conn.String()) // this will fire TTL callbacks which calls closeConn
err := m.cache.Remove(conn.String()) // this will fire TTL callbacks which calls closeConn
if err != nil {
logger.Err(err).Str("cid", conn.String()).Msg("CloseConnsForDevice: cid did not exist in ttlcache")
internal.GetSentryHubFromContextOrDefault(context.Background()).CaptureException(err)
}
}
closed += len(conns)
}
Expand Down Expand Up @@ -222,10 +233,11 @@ func (m *ConnMap) closeConn(conn *Conn) {
h := conn.handler
conns := m.userIDToConn[conn.UserID]
for i := 0; i < len(conns); i++ {
if conns[i].DeviceID == conn.DeviceID {
if conns[i].DeviceID == conn.DeviceID && conns[i].CID == conn.CID {
// delete without preserving order
conns[i] = conns[len(conns)-1]
conns = conns[:len(conns)-1]
conns[i] = nil // allow GC
conns = slices.Delete(conns, i, i+1)
i--
}
}
m.userIDToConn[conn.UserID] = conns
Expand Down
229 changes: 229 additions & 0 deletions sync3/connmap_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
package sync3

import (
	"context"
	"fmt"
	"reflect"
	"sort"
	"sync"
	"testing"
	"time"

	"github.com/matrix-org/sliding-sync/sync3/caches"
)

// User IDs used as fixtures by the tests in this file; they populate the
// UserID field of the ConnIDs the tests create.
const (
alice = "@alice:localhost"
bob = "@bob:localhost"
)

// mustEqual compares got against want and records a test error when they
// differ. It does not abort the test: the failure is reported via t.Errorf and
// execution continues. msg is embedded in the failure output to give context.
func mustEqual[V comparable](t *testing.T, got, want V, msg string) {
	t.Helper()
	if got == want {
		return
	}
	t.Errorf("Equal %s: got '%v' want '%v'", msg, got, want)
}

// TestConnMap checks that a connection created via CreateConn carries the
// requested ConnID and can be fetched again through both Conn and Conns.
func TestConnMap(t *testing.T) {
	cm := NewConnMap(false, time.Minute)
	cid := ConnID{UserID: alice, DeviceID: "A", CID: "room-list"}
	_, cancel := context.WithCancel(context.Background())
	conn := cm.CreateConn(cid, cancel, func() ConnHandler {
		return &mockConnHandler{}
	})
	mustEqual(t, conn.ConnID, cid, "cid mismatch")

	// lookups work
	mustEqual(t, cm.Conn(cid), conn, "*Conn wasn't the same when fetched via Conn(ConnID)")
	conns := cm.Conns(cid.UserID, cid.DeviceID)
	// mustEqual only records an error and keeps going, so stop explicitly here
	// instead of panicking on conns[0] when the slice is empty.
	if len(conns) != 1 {
		t.Fatalf("Conns length mismatch: got %d want 1", len(conns))
	}
	mustEqual(t, conns[0], conn, "*Conn wasn't the same when fetched via Conns()[0]")
}

// TestConnMap_CloseConnsForDevice creates connections for two of alice's
// devices plus one for bob, closes alice's device "A", and verifies that Destroy
// was called on exactly the handlers belonging to that user/device pair.
func TestConnMap_CloseConnsForDevice(t *testing.T) {
	const closedDevice = "A"

	connMap := NewConnMap(false, time.Minute)
	bobCID := ConnID{UserID: bob, DeviceID: "A", CID: "room-list"}
	conns := map[ConnID]*Conn{
		bobCID: nil,
		{UserID: alice, DeviceID: "A", CID: "room-list"}:     nil,
		{UserID: alice, DeviceID: "A", CID: "encryption"}:    nil,
		{UserID: alice, DeviceID: "A", CID: "notifications"}: nil,
		{UserID: alice, DeviceID: "B", CID: "room-list"}:     nil,
		{UserID: alice, DeviceID: "B", CID: "encryption"}:    nil,
		{UserID: alice, DeviceID: "B", CID: "notifications"}: nil,
	}
	newHandler := func() ConnHandler {
		return &mockConnHandler{}
	}
	// Fill in the placeholder nil values with real connections.
	for id := range conns {
		_, cancel := context.WithCancel(context.Background())
		conns[id] = connMap.CreateConn(id, cancel, newHandler)
	}

	connMap.CloseConnsForDevice(alice, closedDevice)
	// Part of the teardown runs in background goroutines; give it time to finish.
	time.Sleep(100 * time.Millisecond)

	// Only alice's connections on the closed device should have been destroyed.
	wantDestroyed := func(cid ConnID) bool {
		return cid.UserID == alice && cid.DeviceID == closedDevice
	}
	assertDestroyedConns(t, conns, wantDestroyed)
}

// TestConnMap_CloseConnsForUser creates connections for two of alice's devices
// plus one for bob, closes everything belonging to alice, and verifies both the
// reported count and that only alice's handlers had Destroy called.
func TestConnMap_CloseConnsForUser(t *testing.T) {
	connMap := NewConnMap(false, time.Minute)
	bobCID := ConnID{UserID: bob, DeviceID: "A", CID: "room-list"}
	conns := map[ConnID]*Conn{
		bobCID: nil,
		{UserID: alice, DeviceID: "A", CID: "room-list"}:     nil,
		{UserID: alice, DeviceID: "A", CID: "encryption"}:    nil,
		{UserID: alice, DeviceID: "A", CID: "notifications"}: nil,
		{UserID: alice, DeviceID: "B", CID: "room-list"}:     nil,
		{UserID: alice, DeviceID: "B", CID: "encryption"}:    nil,
		{UserID: alice, DeviceID: "B", CID: "notifications"}: nil,
	}
	newHandler := func() ConnHandler {
		return &mockConnHandler{}
	}
	// Fill in the placeholder nil values with real connections.
	for id := range conns {
		_, cancel := context.WithCancel(context.Background())
		conns[id] = connMap.CreateConn(id, cancel, newHandler)
	}

	closed := connMap.CloseConnsForUsers([]string{alice})
	// Part of the teardown runs in background goroutines; give it time to finish.
	time.Sleep(100 * time.Millisecond)
	mustEqual(t, closed, 6, "unexpected number of closed conns")

	// Every one of alice's connections, and none of bob's, should be destroyed.
	wantDestroyed := func(cid ConnID) bool {
		return cid.UserID == alice
	}
	assertDestroyedConns(t, conns, wantDestroyed)
}

// TestConnMap_TTLExpiry creates device "A" connections, waits half the TTL,
// creates device "B" connections, and then checks that only the "A" handlers
// have had Destroy called once their TTL has elapsed.
func TestConnMap_TTLExpiry(t *testing.T) {
	cm := NewConnMap(false, time.Second) // 1s expiry
	expiredCIDs := []ConnID{
		{UserID: alice, DeviceID: "A", CID: "room-list"},
		{UserID: alice, DeviceID: "A", CID: "encryption"},
		{UserID: alice, DeviceID: "A", CID: "notifications"},
	}
	cidToConn := map[ConnID]*Conn{}
	for _, cid := range expiredCIDs {
		_, cancel := context.WithCancel(context.Background())
		conn := cm.CreateConn(cid, cancel, func() ConnHandler {
			return &mockConnHandler{}
		})
		cidToConn[cid] = conn
	}
	time.Sleep(time.Millisecond * 500)

	unexpiredCIDs := []ConnID{
		{UserID: alice, DeviceID: "B", CID: "room-list"},
		{UserID: alice, DeviceID: "B", CID: "encryption"},
		{UserID: alice, DeviceID: "B", CID: "notifications"},
	}
	for _, cid := range unexpiredCIDs {
		_, cancel := context.WithCancel(context.Background())
		conn := cm.CreateConn(cid, cancel, func() ConnHandler {
			return &mockConnHandler{}
		})
		cidToConn[cid] = conn
	}

	// All 'A' device conns must have expired by the time we assert. Assuming the
	// TTL clock starts when CreateConn is called (TODO confirm), 'A' is due to
	// expire ~1000ms and 'B' ~1500ms after the start of the test. Asserting at
	// ~1250ms leaves ~250ms of slack on both sides; the previous 510ms sleep left
	// only ~10ms for the expiry callbacks to run, which is fragile on a busy machine.
	time.Sleep(750 * time.Millisecond)

	// Destroy should have been called for all alice|A connections
	assertDestroyedConns(t, cidToConn, func(cid ConnID) bool {
		return cid.DeviceID == "A"
	})
}

// TestConnMap_TTLExpiryStaggeredDevices creates connections in two batches that
// each mix devices "A" and "B", lets only the first batch reach its TTL, and
// checks that exactly those connections were destroyed and that the ConnMap
// still reports the second batch's connections for device "A".
func TestConnMap_TTLExpiryStaggeredDevices(t *testing.T) {
	cm := NewConnMap(false, time.Second) // 1s expiry
	expiredCIDs := []ConnID{
		{UserID: alice, DeviceID: "A", CID: "room-list"},
		{UserID: alice, DeviceID: "B", CID: "encryption"},
		{UserID: alice, DeviceID: "B", CID: "notifications"},
	}
	cidToConn := map[ConnID]*Conn{}
	for _, cid := range expiredCIDs {
		_, cancel := context.WithCancel(context.Background())
		conn := cm.CreateConn(cid, cancel, func() ConnHandler {
			return &mockConnHandler{}
		})
		cidToConn[cid] = conn
	}
	time.Sleep(time.Millisecond * 500)

	unexpiredCIDs := []ConnID{
		{UserID: alice, DeviceID: "B", CID: "room-list"},
		{UserID: alice, DeviceID: "A", CID: "encryption"},
		{UserID: alice, DeviceID: "A", CID: "notifications"},
	}
	for _, cid := range unexpiredCIDs {
		_, cancel := context.WithCancel(context.Background())
		conn := cm.CreateConn(cid, cancel, func() ConnHandler {
			return &mockConnHandler{}
		})
		cidToConn[cid] = conn
	}

	// All expiredCIDs should have expired, none from unexpiredCIDs. Assuming the
	// TTL clock starts when CreateConn is called (TODO confirm), the first batch
	// is due ~1000ms and the second ~1500ms after the start of the test.
	// Asserting at ~1250ms leaves ~250ms of slack on both sides; the previous
	// 510ms sleep left only ~10ms for the expiry callbacks to run.
	time.Sleep(750 * time.Millisecond)

	// Destroy should have been called for all expiredCIDs connections
	assertDestroyedConns(t, cidToConn, func(cid ConnID) bool {
		for _, expCID := range expiredCIDs {
			if expCID.String() == cid.String() {
				return true
			}
		}
		return false
	})

	// double check this by querying connmap
	conns := cm.Conns(alice, "A")
	var gotIDs []string
	for _, c := range conns {
		// t.Log, not t.Logf: the conn string is data, not a format string.
		t.Log(c.String())
		gotIDs = append(gotIDs, c.CID)
	}
	sort.Strings(gotIDs)
	wantIDs := []string{"encryption", "notifications"}
	mustEqual(t, len(conns), 2, "unexpected number of Conns for device")
	if !reflect.DeepEqual(gotIDs, wantIDs) {
		t.Fatalf("unexpected active conns: got %v want %v", gotIDs, wantIDs)
	}
}

// assertDestroyedConns checks every connection in cidToConn: when
// isDestroyedFn(cid) returns true the connection's mock handler must have had
// Destroy called, otherwise it must not. Mismatches are reported through
// mustEqual, so all connections are checked even after a failure.
// Every handler in cidToConn must be a *mockConnHandler; the type assertion
// panics otherwise.
func assertDestroyedConns(t *testing.T, cidToConn map[ConnID]*Conn, isDestroyedFn func(cid ConnID) bool) {
	t.Helper()
	for cid, conn := range cidToConn {
		destroyed := conn.handler.(*mockConnHandler).destroyed()
		if isDestroyedFn(cid) {
			mustEqual(t, destroyed, true, fmt.Sprintf("conn %+v was not destroyed", cid))
		} else {
			mustEqual(t, destroyed, false, fmt.Sprintf("conn %+v was destroyed", cid))
		}
	}
}

// mockConnHandler is a minimal handler for tests which records whether Destroy
// has been called and remembers the last cancel callback it was given.
type mockConnHandler struct {
	// mu guards isDestroyed. The tests note that connection teardown happens in
	// background goroutines, so Destroy may run on a different goroutine from
	// the one performing the assertions.
	mu          sync.Mutex
	isDestroyed bool
	cancel      context.CancelFunc
}

// destroyed reports whether Destroy has been called, reading the flag under mu.
func (c *mockConnHandler) destroyed() bool {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.isDestroyed
}

// OnIncomingRequest is a no-op which returns a nil response and nil error.
func (c *mockConnHandler) OnIncomingRequest(ctx context.Context, cid ConnID, req *Request, isInitial bool, start time.Time) (*Response, error) {
	return nil, nil
}

// OnUpdate is a no-op.
func (c *mockConnHandler) OnUpdate(ctx context.Context, update caches.Update) {}

// PublishEventsUpTo is a no-op.
func (c *mockConnHandler) PublishEventsUpTo(roomID string, nid int64) {}

// Destroy records that the handler has been destroyed.
func (c *mockConnHandler) Destroy() {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.isDestroyed = true
}

// Alive always reports true.
func (c *mockConnHandler) Alive() bool {
	return true // buffer never fills up
}

// SetCancelCallback stores cancel on the mock; nothing in this file invokes it.
func (c *mockConnHandler) SetCancelCallback(cancel context.CancelFunc) {
	c.cancel = cancel
}
10 changes: 3 additions & 7 deletions sync3/handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func NewSync3Handler(
V2: v2Client,
Storage: store,
V2Store: storev2,
ConnMap: sync3.NewConnMap(enablePrometheus),
ConnMap: sync3.NewConnMap(enablePrometheus, 30*time.Minute),
userCaches: &sync.Map{},
Dispatcher: sync3.NewDispatcher(),
GlobalCache: caches.NewGlobalCache(store),
Expand Down Expand Up @@ -453,14 +453,10 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, cancel context.Canc
// because we *either* do the existing check *or* make a new conn. It's important for CreateConn
// to check for an existing connection though, as it's possible for the client to call /sync
// twice for a new connection.
conn, created := h.ConnMap.CreateConn(connID, cancel, func() sync3.ConnHandler {
conn = h.ConnMap.CreateConn(connID, cancel, func() sync3.ConnHandler {
return NewConnState(token.UserID, token.DeviceID, userCache, h.GlobalCache, h.Extensions, h.Dispatcher, h.setupHistVec, h.histVec, h.maxPendingEventUpdates, h.maxTransactionIDDelay)
})
if created {
log.Info().Msg("created new connection")
} else {
log.Info().Msg("using existing connection")
}
log.Info().Msg("created new connection")
return req, conn, nil
}

Expand Down

0 comments on commit c31eb3e

Please sign in to comment.