Skip to content

Commit

Permalink
Merge pull request #432 from matrix-org/kegan/device-list-updates
Browse files Browse the repository at this point in the history
Ensure device list updates are robust to race conditions and network failures
  • Loading branch information
kegsay authored May 10, 2024
2 parents 150d9d6 + f564f2d commit 0d22cf1
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 75 deletions.
7 changes: 5 additions & 2 deletions state/device_data_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func NewDeviceDataTable(db *sqlx.DB) *DeviceDataTable {
func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *internal.DeviceData, err error) {
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
var row DeviceDataRow
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2`, userID, deviceID)
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, userID, deviceID)
if err != nil {
if err == sql.ErrNoRows {
// if there is no device data for this user, it's not an error.
Expand All @@ -70,6 +70,9 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *in
if !swap {
return nil // don't swap
}
// the caller will only look at sent, so make sure what is new is now in sent
result.DeviceLists.Sent = result.DeviceLists.New

// swap over the fields
writeBack := *result
writeBack.DeviceLists.Sent = result.DeviceLists.New
Expand Down Expand Up @@ -104,7 +107,7 @@ func (t *DeviceDataTable) Upsert(dd *internal.DeviceData) (err error) {
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
// select what already exists
var row DeviceDataRow
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2`, dd.UserID, dd.DeviceID)
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, dd.UserID, dd.DeviceID)
if err != nil && err != sql.ErrNoRows {
return err
}
Expand Down
220 changes: 147 additions & 73 deletions state/device_data_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,20 @@ func assertDeviceData(t *testing.T, g, w internal.DeviceData) {
assertVal(t, "FallbackKeyTypes", g.FallbackKeyTypes, w.FallbackKeyTypes)
assertVal(t, "OTKCounts", g.OTKCounts, w.OTKCounts)
assertVal(t, "ChangedBits", g.ChangedBits, w.ChangedBits)
assertVal(t, "DeviceLists", g.DeviceLists, w.DeviceLists)
if w.DeviceLists.Sent != nil {
assertVal(t, "DeviceLists.Sent", g.DeviceLists.Sent, w.DeviceLists.Sent)
}
}

func TestDeviceDataTableSwaps(t *testing.T) {
// Tests OTKCounts and FallbackKeyTypes behaviour
func TestDeviceDataTableOTKCountAndFallbackKeyTypes(t *testing.T) {
db, close := connectToDB(t)
defer close()
table := NewDeviceDataTable(db)
userID := "@bob"
userID := "@TestDeviceDataTableOTKCountAndFallbackKeyTypes"
deviceID := "BOB"

// test accumulating deltas
// these are individual updates from Synapse from /sync v2
deltas := []internal.DeviceData{
{
UserID: userID,
Expand All @@ -46,9 +49,6 @@ func TestDeviceDataTableSwaps(t *testing.T) {
UserID: userID,
DeviceID: deviceID,
FallbackKeyTypes: []string{"foobar"},
DeviceLists: internal.DeviceLists{
New: internal.ToDeviceListChangesMap([]string{"alice"}, nil),
},
},
{
UserID: userID,
Expand All @@ -60,85 +60,157 @@ func TestDeviceDataTableSwaps(t *testing.T) {
{
UserID: userID,
DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
New: internal.ToDeviceListChangesMap([]string{"💣"}, nil),
},
},
}

// apply them
for _, dd := range deltas {
err := table.Upsert(&dd)
assertNoError(t, err)
}

// read them without swap, it should have replaced them correctly.
// Because sync v2 sends the complete OTK count and complete fallback key types
// every time, we always use the latest values. Because we aren't swapping, repeated
// reads produce the same result.
for i := 0; i < 3; i++ {
got, err := table.Select(userID, deviceID, false)
mustNotError(t, err)
want := internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
OTKCounts: map[string]int{
"foo": 99,
},
FallbackKeyTypes: []string{"foobar"},
}
want.SetFallbackKeysChanged()
want.SetOTKCountChanged()
assertDeviceData(t, *got, want)
}
// now we swap the data. This still returns the same values, but the changed bits are no longer set
// on subsequent reads.
got, err := table.Select(userID, deviceID, true)
mustNotError(t, err)
want := internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
OTKCounts: map[string]int{
"foo": 99,
},
FallbackKeyTypes: []string{"foobar"},
DeviceLists: internal.DeviceLists{
New: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
Sent: map[string]int{},
},
}
want.SetFallbackKeysChanged()
want.SetOTKCountChanged()
// check we can read-only select
assertDeviceData(t, *got, want)

// subsequent read
got, err = table.Select(userID, deviceID, false)
mustNotError(t, err)
want = internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
OTKCounts: map[string]int{
"foo": 99,
},
FallbackKeyTypes: []string{"foobar"},
}
assertDeviceData(t, *got, want)
}

// Tests the DeviceLists field
func TestDeviceDataTableDeviceList(t *testing.T) {
db, close := connectToDB(t)
defer close()
table := NewDeviceDataTable(db)
userID := "@TestDeviceDataTableDeviceList"
deviceID := "BOB"

// these are individual updates from Synapse from /sync v2
deltas := []internal.DeviceData{
{
UserID: userID,
DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
New: internal.ToDeviceListChangesMap([]string{"alice"}, nil),
},
},
{
UserID: userID,
DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
New: internal.ToDeviceListChangesMap([]string{"💣"}, nil),
},
},
}
// apply them
for _, dd := range deltas {
err := table.Upsert(&dd)
assertNoError(t, err)
}

// check we can read-only select. This doesn't modify any fields.
for i := 0; i < 3; i++ {
got, err := table.Select(userID, deviceID, false)
assertNoError(t, err)
assertDeviceData(t, *got, want)
assertDeviceData(t, *got, internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
Sent: internal.MapStringInt{}, // until we "swap" we don't consume the New entries
},
})
}
// now swap-er-roo, at this point we still expect the "old" data,
// as it is the first time we swap
// now swap-er-roo, which shifts everything from New into Sent.
got, err := table.Select(userID, deviceID, true)
assertNoError(t, err)
assertDeviceData(t, *got, want)

// changed bits were reset when we swapped
want2 := want
want2.DeviceLists = internal.DeviceLists{
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
New: map[string]int{},
}
want2.ChangedBits = 0
want.ChangedBits = 0
assertDeviceData(t, *got, internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
},
})

// this is permanent, read-only views show this too.
// Since we have swapped previously, we now expect New to be empty
// and Sent to be set. Swap again to clear Sent.
got, err = table.Select(userID, deviceID, true)
got, err = table.Select(userID, deviceID, false)
assertNoError(t, err)
assertDeviceData(t, *got, want2)
assertDeviceData(t, *got, internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
},
})

// We now expect empty DeviceLists, as we swapped twice.
got, err = table.Select(userID, deviceID, false)
got, err = table.Select(userID, deviceID, true)
assertNoError(t, err)
want3 := want2
want3.DeviceLists = internal.DeviceLists{
Sent: map[string]int{},
New: map[string]int{},
}
assertDeviceData(t, *got, want3)
assertDeviceData(t, *got, internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
Sent: internal.MapStringInt{},
},
})

// get back the original state
//err = table.DeleteDevice(userID, deviceID)
assertNoError(t, err)
for _, dd := range deltas {
err = table.Upsert(&dd)
assertNoError(t, err)
}
want.SetFallbackKeysChanged()
want.SetOTKCountChanged()
got, err = table.Select(userID, deviceID, false)
assertNoError(t, err)
assertDeviceData(t, *got, want)

// swap once then add once so both sent and new are populated
// Moves Alice and Bob to Sent
_, err = table.Select(userID, deviceID, true)
// Move original state to Sent by swapping
got, err = table.Select(userID, deviceID, true)
assertNoError(t, err)
assertDeviceData(t, *got, internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
},
})
// Add new entries to New before acknowledging Sent
err = table.Upsert(&internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
Expand All @@ -148,20 +220,18 @@ func TestDeviceDataTableSwaps(t *testing.T) {
})
assertNoError(t, err)

want.ChangedBits = 0

want4 := want
want4.DeviceLists = internal.DeviceLists{
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
New: internal.ToDeviceListChangesMap([]string{"💣"}, []string{"charlie"}),
}
// Without swapping, we expect Alice and Bob in Sent, and Bob and Charlie in New
// Reading without swapping does not move New->Sent, so returns the previous value
got, err = table.Select(userID, deviceID, false)
assertNoError(t, err)
assertDeviceData(t, *got, want4)
assertDeviceData(t, *got, internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
},
})

// another append then consume
// This results in dave to be added to New
// Append even more items to New
err = table.Upsert(&internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
Expand All @@ -170,24 +240,28 @@ func TestDeviceDataTableSwaps(t *testing.T) {
},
})
assertNoError(t, err)

// Now swap: all the combined items in New go into Sent
got, err = table.Select(userID, deviceID, true)
assertNoError(t, err)
want5 := want4
want5.DeviceLists = internal.DeviceLists{
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
New: internal.ToDeviceListChangesMap([]string{"💣"}, []string{"charlie", "dave"}),
}
assertDeviceData(t, *got, want5)
assertDeviceData(t, *got, internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
Sent: internal.ToDeviceListChangesMap([]string{"💣", "dave"}, []string{"charlie", "dave"}),
},
})

// Swapping again clears New
// Swapping again clears Sent out, and since nothing is in New we get an empty list
got, err = table.Select(userID, deviceID, true)
assertNoError(t, err)
want5 = want4
want5.DeviceLists = internal.DeviceLists{
Sent: internal.ToDeviceListChangesMap([]string{"💣"}, []string{"charlie", "dave"}),
New: map[string]int{},
}
assertDeviceData(t, *got, want5)
assertDeviceData(t, *got, internal.DeviceData{
UserID: userID,
DeviceID: deviceID,
DeviceLists: internal.DeviceLists{
Sent: internal.MapStringInt{},
},
})

// delete everything, no data returned
assertNoError(t, table.DeleteDevice(userID, deviceID))
Expand Down
42 changes: 42 additions & 0 deletions tests-integration/extensions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,48 @@ func TestExtensionE2EE(t *testing.T) {
if time.Since(start) >= (500 * time.Millisecond) {
t.Fatalf("sync request did not return immediately with OTK counts")
}

// check that if we lose a device list update and restart from nothing, we see the same update
v2.queueResponse(alice, sync2.SyncResponse{
DeviceLists: struct {
Changed []string `json:"changed,omitempty"`
Left []string `json:"left,omitempty"`
}{
Changed: wantChanged,
Left: wantLeft,
},
})
v2.waitUntilEmpty(t, alice)
res = v3.mustDoV3RequestWithPos(t, aliceToken, res.Pos, sync3.Request{
Lists: map[string]sync3.RequestList{"a": {
Ranges: sync3.SliceRanges{
[2]int64{0, 10}, // doesn't matter
},
}},
// enable the E2EE extension
Extensions: extensions.Request{
E2EE: &extensions.E2EERequest{
Core: extensions.Core{Enabled: &boolTrue},
},
},
})
m.MatchResponse(t, res, m.MatchDeviceLists(wantChanged, wantLeft))
// we actually lost this update: start again and we should see it.
res = v3.mustDoV3Request(t, aliceToken, sync3.Request{
Lists: map[string]sync3.RequestList{"a": {
Ranges: sync3.SliceRanges{
[2]int64{0, 10}, // doesn't matter
},
}},
// enable the E2EE extension
Extensions: extensions.Request{
E2EE: &extensions.E2EERequest{
Core: extensions.Core{Enabled: &boolTrue},
},
},
})
m.MatchResponse(t, res, m.MatchDeviceLists(wantChanged, wantLeft))

}

// Checks that to-device messages are passed from v2 to v3
Expand Down

0 comments on commit 0d22cf1

Please sign in to comment.