diff --git a/userapi/internal/device_list_update.go b/userapi/internal/device_list_update.go index 689949ca00..2f33589fe7 100644 --- a/userapi/internal/device_list_update.go +++ b/userapi/internal/device_list_update.go @@ -181,21 +181,7 @@ func (u *DeviceListUpdater) Start() error { return err } - // Filter out dupe domains, as processServer is going to get all users anyway - seenDomains := make(map[spec.ServerName]struct{}) - newStaleLists := make([]string, 0, len(staleLists)) - for _, userID := range staleLists { - _, domain, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - // non-fatal and should not block starting up - continue - } - if _, ok := seenDomains[domain]; ok { - continue - } - newStaleLists = append(newStaleLists, userID) - seenDomains[domain] = struct{}{} - } + newStaleLists := dedupeStaleLists(staleLists) offset, step := time.Second*10, time.Second if max := len(newStaleLists); max > 120 { step = (time.Second * 120) / time.Duration(max) @@ -599,3 +585,24 @@ func (u *DeviceListUpdater) updateDeviceList(res *fclient.RespUserDevices) error } return nil } + +// dedupeStaleLists de-duplicates the stateList entries using the domain. +// This is used on startup, processServer is getting all users anyway, so +// there is no need to send every user to the workers. +func dedupeStaleLists(staleLists []string) []string { + seenDomains := make(map[spec.ServerName]struct{}) + newStaleLists := make([]string, 0, len(staleLists)) + for _, userID := range staleLists { + _, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + // non-fatal and should not block starting up + continue + } + if _, ok := seenDomains[domain]; ok { + continue + } + newStaleLists = append(newStaleLists, userID) + seenDomains[domain] = struct{}{} + } + return newStaleLists +} diff --git a/userapi/internal/device_list_update_test.go b/userapi/internal/device_list_update_test.go index 10b9c6521f..1b9e4b5c34 100644 --- a/userapi/internal/device_list_update_test.go +++ b/userapi/internal/device_list_update_test.go @@ -428,3 +428,43 @@ func TestDeviceListUpdater_CleanUp(t *testing.T) { } }) } + +func Test_dedupeStateList(t *testing.T) { + alice := "@alice:localhost" + bob := "@bob:localhost" + charlie := "@charlie:notlocalhost" + + tests := []struct { + name string + staleLists []string + want []string + }{ + { + name: "empty stateLists", + staleLists: []string{}, + want: []string{}, + }, + { + name: "single entry", + staleLists: []string{alice}, + want: []string{alice}, + }, + { + name: "multiple entries without dupe servers", + staleLists: []string{alice, charlie}, + want: []string{alice, charlie}, + }, + { + name: "multiple entries with dupe servers", + staleLists: []string{alice, bob, charlie}, + want: []string{alice, charlie}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := dedupeStaleLists(tt.staleLists); !reflect.DeepEqual(got, tt.want) { + t.Errorf("dedupeStaleLists() = %v, want %v", got, tt.want) + } + }) + } +}