From f595aed2c526fd8e94828e60ada29ab403597fda Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 1 Nov 2023 18:36:50 +0000 Subject: [PATCH] Add a separate payload for redacting state So that we don't end up nuking conns unnecessarily. --- pubsub/v2.go | 15 ++++++++++++++- state/accumulator.go | 9 ++++----- state/accumulator_test.go | 6 +++--- sync2/handler2/handler.go | 6 +++--- sync3/handler/handler.go | 8 ++++++++ 5 files changed, 32 insertions(+), 12 deletions(-) diff --git a/pubsub/v2.go b/pubsub/v2.go index 4a886167..317e96cd 100644 --- a/pubsub/v2.go +++ b/pubsub/v2.go @@ -25,6 +25,7 @@ type V2Listener interface { OnDeviceMessages(p *V2DeviceMessages) OnExpiredToken(p *V2ExpiredToken) OnInvalidateRoom(p *V2InvalidateRoom) + OnStateRedaction(p *V2StateRedaction) } type V2Initialise struct { @@ -130,7 +131,17 @@ type V2ExpiredToken struct { func (*V2ExpiredToken) Type() string { return "V2ExpiredToken" } -// V2InvalidateRoom is emitted after a non-incremental state change to a room. +// V2StateRedaction is emitted when a timeline is seen that contains one or more +// redaction events targeting a piece of room state. The redaction will be emitted +// before its corresponding V2Accumulate payload is emitted. +type V2StateRedaction struct { + RoomID string +} + +func (*V2StateRedaction) Type() string { return "V2StateRedaction" } + +// V2InvalidateRoom is emitted after a non-incremental state change to a room, in place +// of a V2Initialise payload. type V2InvalidateRoom struct { RoomID string } @@ -183,6 +194,8 @@ func (v *V2Sub) onMessage(p Payload) { v.receiver.OnExpiredToken(pl) case *V2InvalidateRoom: v.receiver.OnInvalidateRoom(pl) + case *V2StateRedaction: + v.receiver.OnStateRedaction(pl) default: logger.Warn().Str("type", p.Type()).Msg("V2Sub: unhandled payload type") } diff --git a/state/accumulator.go b/state/accumulator.go index ae5fe2ee..bda2640f 100644 --- a/state/accumulator.go +++ b/state/accumulator.go @@ -326,10 +326,9 @@ type AccumulateResult struct { // TimelineNIDs is the list of event nids seen in a sync v2 timeline. Some of these // may already be known to the proxy. TimelineNIDs []int64 - // RequiresReload is set to true when we have accumulated a non-incremental state - // change (typically a redaction) that requires consumers to reload the room state - // from the latest snapshot. - RequiresReload bool + // IncludesStateRedaction is set to true when we have accumulated a redaction to a + // piece of room state. + IncludesStateRedaction bool } // Accumulate internal state from a user's sync response. The timeline order MUST be in the order @@ -546,7 +545,7 @@ func (a *Accumulator) Accumulate(txn *sqlx.Tx, userID, roomID string, timeline s if err != nil { return AccumulateResult{}, err } - result.RequiresReload = currentStateRedactions > 0 + result.IncludesStateRedaction = currentStateRedactions > 0 } if err = a.invitesTable.RemoveSupersededInvites(txn, roomID, postInsertEvents); err != nil { diff --git a/state/accumulator_test.go b/state/accumulator_test.go index 0a73febb..65dc2440 100644 --- a/state/accumulator_test.go +++ b/state/accumulator_test.go @@ -256,7 +256,7 @@ func TestAccumulatorPromptsCacheInvalidation(t *testing.T) { t.Log("We expect 3 new events and no reload required.") assertValue(t, "accResult.NumNew", accResult.NumNew, 3) assertValue(t, "len(accResult.TimelineNIDs)", len(accResult.TimelineNIDs), 3) - assertValue(t, "accResult.RequiresReload", accResult.RequiresReload, false) + assertValue(t, "accResult.IncludesStateRedaction", accResult.IncludesStateRedaction, false) t.Log("Redact the old state event and the message.") timeline = []json.RawMessage{ @@ -274,7 +274,7 @@ func TestAccumulatorPromptsCacheInvalidation(t *testing.T) { t.Log("We expect 2 new events and no reload required.") assertValue(t, "accResult.NumNew", accResult.NumNew, 2) assertValue(t, "len(accResult.TimelineNIDs)", len(accResult.TimelineNIDs), 2) - assertValue(t, "accResult.RequiresReload", accResult.RequiresReload, false) + assertValue(t, "accResult.IncludesStateRedaction", accResult.IncludesStateRedaction, false) t.Log("Redact the latest state event.") timeline = []json.RawMessage{ @@ -291,7 +291,7 @@ func TestAccumulatorPromptsCacheInvalidation(t *testing.T) { t.Log("We expect 1 new event and a reload required.") assertValue(t, "accResult.NumNew", accResult.NumNew, 1) assertValue(t, "len(accResult.TimelineNIDs)", len(accResult.TimelineNIDs), 1) - assertValue(t, "accResult.RequiresReload", accResult.RequiresReload, true) + assertValue(t, "accResult.IncludesStateRedaction", accResult.IncludesStateRedaction, true) } func TestAccumulatorMembershipLogs(t *testing.T) { diff --git a/sync2/handler2/handler.go b/sync2/handler2/handler.go index 3fea03de..66ccf4c2 100644 --- a/sync2/handler2/handler.go +++ b/sync2/handler2/handler.go @@ -302,9 +302,9 @@ func (h *Handler) Accumulate(ctx context.Context, userID, deviceID, roomID strin return err } - // Consumers should reload state before processing new timeline events. - if accResult.RequiresReload { - h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2InvalidateRoom{ + // Consumers should reload state content before processing new timeline events. + if accResult.IncludesStateRedaction { + h.v2Pub.Notify(pubsub.ChanV2, &pubsub.V2StateRedaction{ RoomID: roomID, }) } diff --git a/sync3/handler/handler.go b/sync3/handler/handler.go index 5e6b80e3..45f9ea7f 100644 --- a/sync3/handler/handler.go +++ b/sync3/handler/handler.go @@ -803,6 +803,14 @@ func (h *SyncLiveHandler) OnExpiredToken(p *pubsub.V2ExpiredToken) { h.ConnMap.CloseConnsForDevice(p.UserID, p.DeviceID) } +func (h *SyncLiveHandler) OnStateRedaction(p *pubsub.V2StateRedaction) { + // We only need to reload the global metadata here: mercifully, there isn't anything + // in the user cache that needs to be reloaded after state gets redacted. + ctx, task := internal.StartTask(context.Background(), "OnStateRedaction") + defer task.End() + h.GlobalCache.OnInvalidateRoom(ctx, p.RoomID) +} + func (h *SyncLiveHandler) OnInvalidateRoom(p *pubsub.V2InvalidateRoom) { ctx, task := internal.StartTask(context.Background(), "OnInvalidateRoom") defer task.End()