Skip to content

Commit

Permalink
Merge pull request #422 from matrix-org/kegan/db-size
Browse files Browse the repository at this point in the history
Clean the syncv3_snapshots table periodically
  • Loading branch information
kegsay authored Apr 24, 2024
2 parents b0ab21e + a1b5e48 commit abf5aef
Show file tree
Hide file tree
Showing 3 changed files with 252 additions and 0 deletions.
1 change: 1 addition & 0 deletions cmd/syncv3/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ func main() {
})

go h2.StartV2Pollers()
go h2.Store.Cleaner(time.Hour)
if args[EnvOTLP] != "" {
h3 = otelhttp.NewHandler(h3, "Sync")
}
Expand Down
81 changes: 81 additions & 0 deletions state/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"os"
"strings"
"time"

"golang.org/x/exp/slices"

Expand Down Expand Up @@ -69,6 +70,8 @@ type Storage struct {
ReceiptTable *ReceiptTable
DB *sqlx.DB
MaxTimelineLimit int
shutdownCh chan struct{}
shutdown bool
}

func NewStorage(postgresURI string) *Storage {
Expand Down Expand Up @@ -104,6 +107,7 @@ func NewStorageWithDB(db *sqlx.DB, addPrometheusMetrics bool) *Storage {
ReceiptTable: NewReceiptTable(db),
DB: db,
MaxTimelineLimit: 50,
shutdownCh: make(chan struct{}),
}
}

Expand Down Expand Up @@ -758,6 +762,50 @@ func (s *Storage) LatestEventsInRooms(userID string, roomIDs []string, to int64,
return result, err
}

// Remove state snapshots which cannot be accessed by clients. The latest MaxTimelineEvents
// snapshots must be kept, +1 for the current state. This handles the worst case where all
// MaxTimelineEvents are state events and hence each event makes a new snapshot. We can safely
// delete all snapshots older than this, as it's not possible to reach this snapshot as the proxy
// does not handle historical state (deferring to the homeserver for that).
func (s *Storage) RemoveInaccessibleStateSnapshots() error {
numToKeep := s.MaxTimelineLimit + 1
// Create a CTE which ranks each snapshot so we can figure out which snapshots to delete
// then execute the delete using the CTE.
//
// A per-room version of this query:
// WITH ranked_snapshots AS (
// SELECT
// snapshot_id,
// room_id,
// ROW_NUMBER() OVER (PARTITION BY room_id ORDER BY snapshot_id DESC) AS row_num
// FROM syncv3_snapshots
// )
// DELETE FROM syncv3_snapshots WHERE snapshot_id IN(
// SELECT snapshot_id FROM ranked_snapshots WHERE row_num > 51 AND room_id='!....'
// );
awfulQuery := fmt.Sprintf(`WITH ranked_snapshots AS (
SELECT
snapshot_id,
room_id,
ROW_NUMBER() OVER (PARTITION BY room_id ORDER BY snapshot_id DESC) AS row_num
FROM
syncv3_snapshots
)
DELETE FROM syncv3_snapshots USING ranked_snapshots
WHERE syncv3_snapshots.snapshot_id = ranked_snapshots.snapshot_id
AND ranked_snapshots.row_num > %d;`, numToKeep)

result, err := s.DB.Exec(awfulQuery)
if err != nil {
return fmt.Errorf("failed to RemoveInaccessibleStateSnapshots: Exec %s", err)
}
rowsAffected, err := result.RowsAffected()
if err == nil {
logger.Info().Int64("rows_affected", rowsAffected).Msg("RemoveInaccessibleStateSnapshots: deleted rows")
}
return nil
}

func (s *Storage) GetClosestPrevBatch(roomID string, eventNID int64) (prevBatch string) {
var err error
sqlutil.WithTransaction(s.DB, func(txn *sqlx.Tx) error {
Expand Down Expand Up @@ -1024,6 +1072,34 @@ func (s *Storage) AllJoinedMembers(txn *sqlx.Tx, tempTableName string) (joinedMe
return joinedMembers, metadata, nil
}

func (s *Storage) Cleaner(n time.Duration) {
Loop:
for {
select {
case <-time.After(n):
now := time.Now()
boundaryTime := now.Add(-1 * n)
if n < time.Hour {
boundaryTime = now.Add(-1 * time.Hour)
}
logger.Info().Time("boundaryTime", boundaryTime).Msg("Cleaner running")
err := s.TransactionsTable.Clean(boundaryTime)
if err != nil {
logger.Warn().Err(err).Msg("failed to clean txn ID table")
sentry.CaptureException(err)
}
// we also want to clean up stale state snapshots which are inaccessible, to
// keep the size of the syncv3_snapshots table low.
if err = s.RemoveInaccessibleStateSnapshots(); err != nil {
logger.Warn().Err(err).Msg("failed to remove inaccessible state snapshots")
sentry.CaptureException(err)
}
case <-s.shutdownCh:
break Loop
}
}
}

func (s *Storage) LatestEventNIDInRooms(roomIDs []string, highestNID int64) (roomToNID map[string]int64, err error) {
roomToNID = make(map[string]int64)
err = sqlutil.WithTransaction(s.Accumulator.db, func(txn *sqlx.Tx) error {
Expand Down Expand Up @@ -1113,6 +1189,11 @@ func (s *Storage) determineJoinedRoomsFromMemberships(membershipEvents []Event)
}

func (s *Storage) Teardown() {
if !s.shutdown {
s.shutdown = true
close(s.shutdownCh)
}

err := s.Accumulator.db.Close()
if err != nil {
panic("Storage.Teardown: " + err.Error())
Expand Down
170 changes: 170 additions & 0 deletions state/storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ import (
"context"
"encoding/json"
"fmt"
"math/rand"
"reflect"
"sort"
"testing"
"time"

"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/sliding-sync/sync2"

"github.com/jmoiron/sqlx"
Expand Down Expand Up @@ -913,6 +915,174 @@ func TestStorage_FetchMemberships(t *testing.T) {
assertValue(t, "joins", leaves, []string{"@chris:test", "@david:test", "@glory:test", "@helen:test"})
}

type persistOpts struct {
withInitialEvents bool
numTimelineEvents int
ofWhichNumState int
}

func mustPersistEvents(t *testing.T, roomID string, store *Storage, opts persistOpts) {
t.Helper()
var events []json.RawMessage
if opts.withInitialEvents {
events = createInitialEvents(t, userID)
}
numAddedStateEvents := 0
for i := 0; i < opts.numTimelineEvents; i++ {
var ev json.RawMessage
if numAddedStateEvents < opts.ofWhichNumState {
numAddedStateEvents++
ev = testutils.NewStateEvent(t, "some_kind_of_state", fmt.Sprintf("%d", rand.Int63()), userID, map[string]interface{}{
"num": numAddedStateEvents,
})
} else {
ev = testutils.NewEvent(t, "some_kind_of_message", userID, map[string]interface{}{
"msg": "yep",
})
}
events = append(events, ev)
}
mustAccumulate(t, store, roomID, events)
}

func mustAccumulate(t *testing.T, store *Storage, roomID string, events []json.RawMessage) {
t.Helper()
_, err := store.Accumulate(userID, roomID, sync2.TimelineResponse{
Events: events,
})
if err != nil {
t.Fatalf("Failed to accumulate: %s", err)
}
}

func mustHaveNumSnapshots(t *testing.T, db *sqlx.DB, roomID string, numSnapshots int) {
t.Helper()
var val int
err := db.QueryRow(`SELECT count(*) FROM syncv3_snapshots WHERE room_id=$1`, roomID).Scan(&val)
if err != nil {
t.Fatalf("mustHaveNumSnapshots: %s", err)
}
if val != numSnapshots {
t.Fatalf("mustHaveNumSnapshots: got %d want %d snapshots", val, numSnapshots)
}
}

func mustNotError(t *testing.T, err error) {
t.Helper()
if err == nil {
return
}
t.Fatalf("err: %s", err)
}

func TestRemoveInaccessibleStateSnapshots(t *testing.T) {
store := NewStorage(postgresConnectionString)
store.MaxTimelineLimit = 50 // we nuke if we have >50+1 snapshots

roomOnlyMessages := "!TestRemoveInaccessibleStateSnapshots_roomOnlyMessages:localhost"
mustPersistEvents(t, roomOnlyMessages, store, persistOpts{
withInitialEvents: true,
numTimelineEvents: 100,
ofWhichNumState: 0,
})
roomOnlyState := "!TestRemoveInaccessibleStateSnapshots_roomOnlyState:localhost"
mustPersistEvents(t, roomOnlyState, store, persistOpts{
withInitialEvents: true,
numTimelineEvents: 100,
ofWhichNumState: 100,
})
roomPartialStateAndMessages := "!TestRemoveInaccessibleStateSnapshots_roomPartialStateAndMessages:localhost"
mustPersistEvents(t, roomPartialStateAndMessages, store, persistOpts{
withInitialEvents: true,
numTimelineEvents: 100,
ofWhichNumState: 30,
})
roomOverwriteState := "TestRemoveInaccessibleStateSnapshots_roomOverwriteState:localhost"
mustPersistEvents(t, roomOverwriteState, store, persistOpts{
withInitialEvents: true,
})
mustAccumulate(t, store, roomOverwriteState, []json.RawMessage{testutils.NewStateEvent(t, "overwrite", "", userID, map[string]interface{}{"val": 1})})
mustAccumulate(t, store, roomOverwriteState, []json.RawMessage{testutils.NewStateEvent(t, "overwrite", "", userID, map[string]interface{}{"val": 2})})
mustHaveNumSnapshots(t, store.DB, roomOnlyMessages, 4) // initial state only, one for each state event
mustHaveNumSnapshots(t, store.DB, roomOnlyState, 104) // initial state + 100 state events
mustHaveNumSnapshots(t, store.DB, roomPartialStateAndMessages, 34) // initial state + 30 state events
mustHaveNumSnapshots(t, store.DB, roomOverwriteState, 6) // initial state + 2 overwrite state events
mustNotError(t, store.RemoveInaccessibleStateSnapshots())
mustHaveNumSnapshots(t, store.DB, roomOnlyMessages, 4) // it should not be touched as 4 < 51
mustHaveNumSnapshots(t, store.DB, roomOnlyState, 51) // it should be capped at 51
mustHaveNumSnapshots(t, store.DB, roomPartialStateAndMessages, 34) // it should not be touched as 34 < 51
mustHaveNumSnapshots(t, store.DB, roomOverwriteState, 6) // it should not be touched as 6 < 51
// calling it again does nothing
mustNotError(t, store.RemoveInaccessibleStateSnapshots())
mustHaveNumSnapshots(t, store.DB, roomOnlyMessages, 4)
mustHaveNumSnapshots(t, store.DB, roomOnlyState, 51)
mustHaveNumSnapshots(t, store.DB, roomPartialStateAndMessages, 34)
mustHaveNumSnapshots(t, store.DB, roomOverwriteState, 6) // it should not be touched as 6 < 51
// adding one extra state snapshot to each room and repeating RemoveInaccessibleStateSnapshots
mustPersistEvents(t, roomOnlyMessages, store, persistOpts{numTimelineEvents: 1, ofWhichNumState: 1})
mustPersistEvents(t, roomOnlyState, store, persistOpts{numTimelineEvents: 1, ofWhichNumState: 1})
mustPersistEvents(t, roomPartialStateAndMessages, store, persistOpts{numTimelineEvents: 1, ofWhichNumState: 1})
mustNotError(t, store.RemoveInaccessibleStateSnapshots())
mustHaveNumSnapshots(t, store.DB, roomOnlyMessages, 5)
mustHaveNumSnapshots(t, store.DB, roomOnlyState, 51) // still capped
mustHaveNumSnapshots(t, store.DB, roomPartialStateAndMessages, 35)
// adding 51 timeline events and repeating RemoveInaccessibleStateSnapshots does nothing
mustPersistEvents(t, roomOnlyMessages, store, persistOpts{numTimelineEvents: 51})
mustPersistEvents(t, roomOnlyState, store, persistOpts{numTimelineEvents: 51})
mustPersistEvents(t, roomPartialStateAndMessages, store, persistOpts{numTimelineEvents: 51})
mustNotError(t, store.RemoveInaccessibleStateSnapshots())
mustHaveNumSnapshots(t, store.DB, roomOnlyMessages, 5)
mustHaveNumSnapshots(t, store.DB, roomOnlyState, 51)
mustHaveNumSnapshots(t, store.DB, roomPartialStateAndMessages, 35)

// overwrite 52 times and check the current state is 52 (shows we are keeping the right snapshots)
for i := 0; i < 52; i++ {
mustAccumulate(t, store, roomOverwriteState, []json.RawMessage{testutils.NewStateEvent(t, "overwrite", "", userID, map[string]interface{}{"val": 1 + i})})
}
mustHaveNumSnapshots(t, store.DB, roomOverwriteState, 58)
mustNotError(t, store.RemoveInaccessibleStateSnapshots())
mustHaveNumSnapshots(t, store.DB, roomOverwriteState, 51)
roomsTable := NewRoomsTable(store.DB)
mustNotError(t, sqlutil.WithTransaction(store.DB, func(txn *sqlx.Tx) error {
snapID, err := roomsTable.CurrentAfterSnapshotID(txn, roomOverwriteState)
if err != nil {
return err
}
state, err := store.StateSnapshot(snapID)
if err != nil {
return err
}
// find the 'overwrite' event and make sure the val is 52
for _, ev := range state {
evv := gjson.ParseBytes(ev)
if evv.Get("type").Str != "overwrite" {
continue
}
if evv.Get("content.val").Int() != 52 {
return fmt.Errorf("val for overwrite state event was not 52: %v", evv.Raw)
}
}
return nil
}))
}

func createInitialEvents(t *testing.T, creator string) []json.RawMessage {
t.Helper()
baseTimestamp := time.Now()
var pl gomatrixserverlib.PowerLevelContent
pl.Defaults()
pl.Users = map[string]int64{
creator: 100,
}
// all with the same timestamp as they get made atomically
return []json.RawMessage{
testutils.NewStateEvent(t, "m.room.create", "", creator, map[string]interface{}{"creator": creator}, testutils.WithTimestamp(baseTimestamp)),
testutils.NewJoinEvent(t, creator, testutils.WithTimestamp(baseTimestamp)),
testutils.NewStateEvent(t, "m.room.power_levels", "", creator, pl, testutils.WithTimestamp(baseTimestamp)),
testutils.NewStateEvent(t, "m.room.join_rules", "", creator, map[string]interface{}{"join_rule": "public"}, testutils.WithTimestamp(baseTimestamp)),
}
}

func cleanDB(t *testing.T) error {
// make a fresh DB which is unpolluted from other tests
db, close := connectToDB(t)
Expand Down

0 comments on commit abf5aef

Please sign in to comment.