From db83789654ade3cf4f900e8fbcaa742b60c5dc6c Mon Sep 17 00:00:00 2001 From: devonh Date: Fri, 15 Sep 2023 15:25:09 +0000 Subject: [PATCH] Move pseudoID ClientEvent hotswapping to a common location (#3199) Fixes a variety of issues where clients were receiving pseudoIDs in places that should be userIDs. This change makes pseudoIDs work with sliding sync & element x. --------- Co-authored-by: Till <2353100+S7evinK@users.noreply.github.com> --- clientapi/routing/state.go | 21 +- internal/eventutil/events.go | 6 +- syncapi/routing/getevent.go | 18 +- syncapi/routing/relations.go | 21 +- syncapi/routing/search.go | 20 +- syncapi/streams/stream_invite.go | 15 +- syncapi/streams/stream_pdu.go | 131 --------- syncapi/synctypes/clientevent.go | 365 ++++++++++++++++++++------ syncapi/synctypes/clientevent_test.go | 30 ++- syncapi/types/types.go | 35 ++- syncapi/types/types_test.go | 29 +- userapi/consumers/roomserver.go | 43 +-- userapi/util/notify_test.go | 15 +- 13 files changed, 421 insertions(+), 328 deletions(-) diff --git a/clientapi/routing/state.go b/clientapi/routing/state.go index 6f363349b6..18f9a0e9cd 100644 --- a/clientapi/routing/state.go +++ b/clientapi/routing/state.go @@ -172,23 +172,16 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a } } for _, ev := range stateAfterRes.StateEvents { - sender := spec.UserID{} - userID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID()) - if err == nil && userID != nil { - sender = *userID - } - - sk := ev.StateKey() - if sk != nil && *sk != "" { - skUserID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey())) - if err == nil && skUserID != nil { - skString := skUserID.String() - sk = &skString - } + clientEvent, err := synctypes.ToClientEvent(ev, synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("Failed converting to ClientEvent") + continue } stateEvents = append( stateEvents, - synctypes.ToClientEvent(ev, synctypes.FormatAll, sender.String(), sk, ev.Unsigned()), + *clientEvent, ) } } diff --git a/internal/eventutil/events.go b/internal/eventutil/events.go index b3523e1296..40d62fd68e 100644 --- a/internal/eventutil/events.go +++ b/internal/eventutil/events.go @@ -176,11 +176,13 @@ func RedactEvent(ctx context.Context, redactionEvent, redactedEvent gomatrixserv return fmt.Errorf("RedactEvent: redactionEvent isn't a redaction event, is '%s'", redactionEvent.Type()) } redactedEvent.Redact() - senderID, err := querier.QueryUserIDForSender(ctx, redactedEvent.RoomID(), redactionEvent.SenderID()) + clientEvent, err := synctypes.ToClientEvent(redactionEvent, synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return querier.QueryUserIDForSender(ctx, roomID, senderID) + }) if err != nil { return err } - redactedBecause := synctypes.ToClientEvent(redactionEvent, synctypes.FormatSync, senderID.String(), redactionEvent.StateKey(), redactionEvent.Unsigned()) + redactedBecause := clientEvent if err := redactedEvent.SetUnsignedField("redacted_because", redactedBecause); err != nil { return err } diff --git a/syncapi/routing/getevent.go b/syncapi/routing/getevent.go index 886b116758..c089539f05 100644 --- a/syncapi/routing/getevent.go +++ b/syncapi/routing/getevent.go @@ -118,25 +118,19 @@ func GetEvent( } } - senderUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *roomID, events[0].SenderID()) - if err != nil || senderUserID == nil { - util.GetLogger(req.Context()).WithError(err).WithField("senderID", events[0].SenderID()).WithField("roomID", *roomID).Error("QueryUserIDForSender errored or returned nil-user ID when user should be part of a room") + clientEvent, err := synctypes.ToClientEvent(events[0], synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) + if err != nil { + util.GetLogger(req.Context()).WithError(err).WithField("senderID", events[0].SenderID()).WithField("roomID", *roomID).Error("Failed converting to ClientEvent") return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: spec.Unknown("internal server error"), } } - sk := events[0].StateKey() - if sk != nil && *sk != "" { - skUserID, err := rsAPI.QueryUserIDForSender(ctx, events[0].RoomID(), spec.SenderID(*events[0].StateKey())) - if err == nil && skUserID != nil { - skString := skUserID.String() - sk = &skString - } - } return util.JSONResponse{ Code: http.StatusOK, - JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, senderUserID.String(), sk, events[0].Unsigned()), + JSON: *clientEvent, } } diff --git a/syncapi/routing/relations.go b/syncapi/routing/relations.go index b451a7e2e8..935ba83b30 100644 --- a/syncapi/routing/relations.go +++ b/syncapi/routing/relations.go @@ -130,23 +130,16 @@ func Relations( // type if it was specified. res.Chunk = make([]synctypes.ClientEvent, 0, len(filteredEvents)) for _, event := range filteredEvents { - sender := spec.UserID{} - userID, err := rsAPI.QueryUserIDForSender(req.Context(), *roomID, event.SenderID()) - if err == nil && userID != nil { - sender = *userID - } - - sk := event.StateKey() - if sk != nil && *sk != "" { - skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *roomID, spec.SenderID(*event.StateKey())) - if err == nil && skUserID != nil { - skString := skUserID.String() - sk = &skString - } + clientEvent, err := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) + }) + if err != nil { + util.GetLogger(req.Context()).WithError(err).WithField("senderID", events[0].SenderID()).WithField("roomID", *roomID).Error("Failed converting to ClientEvent") + continue } res.Chunk = append( res.Chunk, - synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender.String(), sk, event.Unsigned()), + *clientEvent, ) } diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go index f574781aa8..4a8be9f495 100644 --- a/syncapi/routing/search.go +++ b/syncapi/routing/search.go @@ -230,20 +230,14 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts profileInfos[userID.String()] = profile } - sender := spec.UserID{} - userID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), event.SenderID()) - if err == nil && userID != nil { - sender = *userID + clientEvent, err := synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) + if err != nil { + util.GetLogger(req.Context()).WithError(err).WithField("senderID", event.SenderID()).Error("Failed converting to ClientEvent") + continue } - sk := event.StateKey() - if sk != nil && *sk != "" { - skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey())) - if err == nil && skUserID != nil { - skString := skUserID.String() - sk = &skString - } - } results = append(results, Result{ Context: SearchContextResponse{ Start: startToken.String(), @@ -257,7 +251,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts ProfileInfo: profileInfos, }, Rank: eventScore[event.EventID()].Score, - Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender.String(), sk, event.Unsigned()), + Result: *clientEvent, }) roomGroup := groups[event.RoomID().String()] roomGroup.Results = append(roomGroup.Results, event.EventID()) diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go index 1424dc2e61..a3634c03fe 100644 --- a/syncapi/streams/stream_invite.go +++ b/syncapi/streams/stream_invite.go @@ -75,20 +75,15 @@ func (p *InviteStreamProvider) IncrementalSync( user = *sender } - sk := inviteEvent.StateKey() - if sk != nil && *sk != "" { - skUserID, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey())) - if err == nil && skUserID != nil { - skString := skUserID.String() - sk = &skString - } - } - // skip ignored user events if _, ok := req.IgnoredUsers.List[user.String()]; ok { continue } - ir := types.NewInviteResponse(inviteEvent, user, sk, eventFormat) + ir, err := types.NewInviteResponse(ctx, p.rsAPI, inviteEvent, eventFormat) + if err != nil { + req.Log.WithError(err).Error("failed creating invite response") + continue + } req.Response.Rooms.Invite[roomID] = ir } diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index eb1f0ef2e7..3abb0b3c6d 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -3,7 +3,6 @@ package streams import ( "context" "database/sql" - "encoding/json" "fmt" "time" @@ -16,8 +15,6 @@ import ( "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib/spec" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/gomatrixserverlib" @@ -359,23 +356,6 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( // Now that we've filtered the timeline, work out which state events are still // left. Anything that appears in the filtered timeline will be removed from the // "state" section and kept in "timeline". - - // update the powerlevel event for timeline events - for i, ev := range events { - if ev.Version() != gomatrixserverlib.RoomVersionPseudoIDs { - continue - } - if ev.Type() != spec.MRoomPowerLevels || !ev.StateKeyEquals("") { - continue - } - var newEvent gomatrixserverlib.PDU - newEvent, err = p.updatePowerLevelEvent(ctx, ev, eventFormat) - if err != nil { - return r.From, err - } - events[i] = &rstypes.HeaderedEvent{PDU: newEvent} - } - sEvents := gomatrixserverlib.HeaderedReverseTopologicalOrdering( gomatrixserverlib.ToPDUs(removeDuplicates(delta.StateEvents, events)), gomatrixserverlib.TopologicalOrderByAuthEvents, @@ -390,15 +370,6 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( continue } delta.StateEvents[i-skipped] = he - // update the powerlevel event for state events - if ev.Version() == gomatrixserverlib.RoomVersionPseudoIDs && ev.Type() == spec.MRoomPowerLevels && ev.StateKeyEquals("") { - var newEvent gomatrixserverlib.PDU - newEvent, err = p.updatePowerLevelEvent(ctx, he, eventFormat) - if err != nil { - return r.From, err - } - delta.StateEvents[i-skipped] = &rstypes.HeaderedEvent{PDU: newEvent} - } } delta.StateEvents = delta.StateEvents[:len(sEvents)-skipped] @@ -468,79 +439,6 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( return latestPosition, nil } -func (p *PDUStreamProvider) updatePowerLevelEvent(ctx context.Context, ev *rstypes.HeaderedEvent, eventFormat synctypes.ClientEventFormat) (gomatrixserverlib.PDU, error) { - pls, err := gomatrixserverlib.NewPowerLevelContentFromEvent(ev) - if err != nil { - return nil, err - } - newPls := make(map[string]int64) - var userID *spec.UserID - for user, level := range pls.Users { - if eventFormat != synctypes.FormatSyncFederation { - userID, err = p.rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(user)) - if err != nil { - return nil, err - } - user = userID.String() - } - newPls[user] = level - } - var newPlBytes, newEv []byte - newPlBytes, err = json.Marshal(newPls) - if err != nil { - return nil, err - } - newEv, err = sjson.SetRawBytes(ev.JSON(), "content.users", newPlBytes) - if err != nil { - return nil, err - } - - // do the same for prev content - prevContent := gjson.GetBytes(ev.JSON(), "unsigned.prev_content") - if !prevContent.Exists() { - var evNew gomatrixserverlib.PDU - evNew, err = gomatrixserverlib.MustGetRoomVersion(ev.Version()).NewEventFromTrustedJSONWithEventID(ev.EventID(), newEv, false) - if err != nil { - return nil, err - } - - return evNew, err - } - pls = gomatrixserverlib.PowerLevelContent{} - err = json.Unmarshal([]byte(prevContent.Raw), &pls) - if err != nil { - return nil, err - } - - newPls = make(map[string]int64) - for user, level := range pls.Users { - if eventFormat != synctypes.FormatSyncFederation { - userID, err = p.rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(user)) - if err != nil { - return nil, err - } - user = userID.String() - } - newPls[user] = level - } - newPlBytes, err = json.Marshal(newPls) - if err != nil { - return nil, err - } - newEv, err = sjson.SetRawBytes(newEv, "unsigned.prev_content.users", newPlBytes) - if err != nil { - return nil, err - } - - var evNew gomatrixserverlib.PDU - evNew, err = gomatrixserverlib.MustGetRoomVersion(ev.Version()).NewEventFromTrustedJSONWithEventID(ev.EventID(), newEv, false) - if err != nil { - return nil, err - } - - return evNew, err -} - // applyHistoryVisibilityFilter gets the current room state and supplies it to ApplyHistoryVisibilityFilter, to make // sure we always return the required events in the timeline. func applyHistoryVisibilityFilter( @@ -690,35 +588,6 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( prevBatch.Decrement() } - // Update powerlevel events for timeline events - for i, ev := range events { - if ev.Version() != gomatrixserverlib.RoomVersionPseudoIDs { - continue - } - if ev.Type() != spec.MRoomPowerLevels || !ev.StateKeyEquals("") { - continue - } - newEvent, err := p.updatePowerLevelEvent(ctx, ev, eventFormat) - if err != nil { - return nil, err - } - events[i] = &rstypes.HeaderedEvent{PDU: newEvent} - } - // Update powerlevel events for state events - for i, ev := range stateEvents { - if ev.Version() != gomatrixserverlib.RoomVersionPseudoIDs { - continue - } - if ev.Type() != spec.MRoomPowerLevels || !ev.StateKeyEquals("") { - continue - } - newEvent, err := p.updatePowerLevelEvent(ctx, ev, eventFormat) - if err != nil { - return nil, err - } - stateEvents[i] = &rstypes.HeaderedEvent{PDU: newEvent} - } - jr.Timeline.PrevBatch = prevBatch jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), eventFormat, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) diff --git a/syncapi/synctypes/clientevent.go b/syncapi/synctypes/clientevent.go index e0616e11d6..6812f83322 100644 --- a/syncapi/synctypes/clientevent.go +++ b/syncapi/synctypes/clientevent.go @@ -22,6 +22,8 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" ) // PrevEventRef represents a reference to a previous event in a state event upgrade @@ -78,59 +80,62 @@ func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat, if se == nil { continue // TODO: shouldn't happen? } - if format == FormatSyncFederation { - evs = append(evs, ToClientEvent(se, format, string(se.SenderID()), se.StateKey(), spec.RawJSON(se.Unsigned()))) + ev, err := ToClientEvent(se, format, userIDForSender) + if err != nil { + logrus.WithError(err).Warn("Failed converting event to ClientEvent") continue } + evs = append(evs, *ev) + } + return evs +} - sender := spec.UserID{} - userID, err := userIDForSender(se.RoomID(), se.SenderID()) - if err == nil && userID != nil { - sender = *userID - } +// ToClientEventDefault converts a single server event to a client event. +// It provides default logic for event.SenderID & event.StateKey -> userID conversions. +func ToClientEventDefault(userIDQuery spec.UserIDForSender, event gomatrixserverlib.PDU) ClientEvent { + ev, err := ToClientEvent(event, FormatAll, userIDQuery) + if err != nil { + return ClientEvent{} + } + return *ev +} - sk := se.StateKey() - if sk != nil && *sk != "" { - skUserID, err := userIDForSender(se.RoomID(), spec.SenderID(*sk)) - if err == nil && skUserID != nil { - skString := skUserID.String() - sk = &skString - } +// If provided state key is a user ID (state keys beginning with @ are reserved for this purpose) +// fetch it's associated sender ID and use that instead. Otherwise returns the same state key back. +// +// # This function either returns the state key that should be used, or an error +// +// TODO: handle failure cases better (e.g. no sender ID) +func FromClientStateKey(roomID spec.RoomID, stateKey string, senderIDQuery spec.SenderIDForUser) (*string, error) { + if len(stateKey) >= 1 && stateKey[0] == '@' { + parsedStateKey, err := spec.NewUserID(stateKey, true) + if err != nil { + // If invalid user ID, then there is no associated state event. + return nil, fmt.Errorf("Provided state key begins with @ but is not a valid user ID: %w", err) } - - unsigned := se.Unsigned() - var prev PrevEventRef - if err := json.Unmarshal(se.Unsigned(), &prev); err == nil && prev.PrevSenderID != "" { - prevUserID, err := userIDForSender(se.RoomID(), spec.SenderID(prev.PrevSenderID)) - if err == nil && userID != nil { - prev.PrevSenderID = prevUserID.String() - } else { - errString := "userID unknown" - if err != nil { - errString = err.Error() - } - logrus.Warnf("Failed to find userID for prev_sender in ClientEvent: %s", errString) - // NOTE: Not much can be done here, so leave the previous value in place. - } - unsigned, err = json.Marshal(prev) - if err != nil { - logrus.Errorf("Failed to marshal unsigned content for ClientEvent: %s", err.Error()) - continue - } + senderID, err := senderIDQuery(roomID, *parsedStateKey) + if err != nil { + return nil, fmt.Errorf("Failed to query sender ID: %w", err) + } + if senderID == nil { + // If no sender ID, then there is no associated state event. + return nil, fmt.Errorf("No associated sender ID found.") } - evs = append(evs, ToClientEvent(se, format, sender.String(), sk, spec.RawJSON(unsigned))) + newStateKey := string(*senderID) + return &newStateKey, nil + } else { + return &stateKey, nil } - return evs } // ToClientEvent converts a single server event to a client event. -func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender string, stateKey *string, unsigned spec.RawJSON) ClientEvent { +func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, userIDForSender spec.UserIDForSender) (*ClientEvent, error) { ce := ClientEvent{ - Content: spec.RawJSON(se.Content()), - Sender: sender, + Content: se.Content(), + Sender: string(se.SenderID()), Type: se.Type(), - StateKey: stateKey, - Unsigned: unsigned, + StateKey: se.StateKey(), + Unsigned: se.Unsigned(), OriginServerTS: se.OriginServerTS(), EventID: se.EventID(), Redacts: se.Redacts(), @@ -148,58 +153,268 @@ func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender st // TODO: Set Signatures & Hashes fields } - if format != FormatSyncFederation { - if se.Version() == gomatrixserverlib.RoomVersionPseudoIDs { - ce.SenderKey = se.SenderID() + if format != FormatSyncFederation && se.Version() == gomatrixserverlib.RoomVersionPseudoIDs { + err := updatePseudoIDs(&ce, se, userIDForSender, format) + if err != nil { + return nil, err } } - return ce + + return &ce, nil } -// ToClientEvent converts a single server event to a client event. -// It provides default logic for event.SenderID & event.StateKey -> userID conversions. -func ToClientEventDefault(userIDQuery spec.UserIDForSender, event gomatrixserverlib.PDU) ClientEvent { - sender := spec.UserID{} - userID, err := userIDQuery(event.RoomID(), event.SenderID()) +func updatePseudoIDs(ce *ClientEvent, se gomatrixserverlib.PDU, userIDForSender spec.UserIDForSender, format ClientEventFormat) error { + ce.SenderKey = se.SenderID() + + userID, err := userIDForSender(se.RoomID(), se.SenderID()) if err == nil && userID != nil { - sender = *userID + ce.Sender = userID.String() } - sk := event.StateKey() + sk := se.StateKey() if sk != nil && *sk != "" { - skUserID, err := userIDQuery(event.RoomID(), spec.SenderID(*event.StateKey())) + skUserID, err := userIDForSender(se.RoomID(), spec.SenderID(*sk)) if err == nil && skUserID != nil { skString := skUserID.String() - sk = &skString + ce.StateKey = &skString } } - return ToClientEvent(event, FormatAll, sender.String(), sk, event.Unsigned()) + + var prev PrevEventRef + if err := json.Unmarshal(se.Unsigned(), &prev); err == nil && prev.PrevSenderID != "" { + prevUserID, err := userIDForSender(se.RoomID(), spec.SenderID(prev.PrevSenderID)) + if err == nil && userID != nil { + prev.PrevSenderID = prevUserID.String() + } else { + errString := "userID unknown" + if err != nil { + errString = err.Error() + } + logrus.Warnf("Failed to find userID for prev_sender in ClientEvent: %s", errString) + // NOTE: Not much can be done here, so leave the previous value in place. + } + ce.Unsigned, err = json.Marshal(prev) + if err != nil { + err = fmt.Errorf("Failed to marshal unsigned content for ClientEvent: %w", err) + return err + } + } + + switch se.Type() { + case spec.MRoomCreate: + updatedContent, err := updateCreateEvent(se.Content(), userIDForSender, se.RoomID()) + if err != nil { + err = fmt.Errorf("Failed to update m.room.create event for ClientEvent: %w", err) + return err + } + ce.Content = updatedContent + case spec.MRoomMember: + updatedEvent, err := updateInviteEvent(userIDForSender, se, format) + if err != nil { + err = fmt.Errorf("Failed to update m.room.member event for ClientEvent: %w", err) + return err + } + if updatedEvent != nil { + ce.Unsigned = updatedEvent.Unsigned() + } + case spec.MRoomPowerLevels: + updatedEvent, err := updatePowerLevelEvent(userIDForSender, se, format) + if err != nil { + err = fmt.Errorf("Failed update m.room.power_levels event for ClientEvent: %w", err) + return err + } + if updatedEvent != nil { + ce.Content = updatedEvent.Content() + ce.Unsigned = updatedEvent.Unsigned() + } + } + + return nil } -// If provided state key is a user ID (state keys beginning with @ are reserved for this purpose) -// fetch it's associated sender ID and use that instead. Otherwise returns the same state key back. -// -// # This function either returns the state key that should be used, or an error -// -// TODO: handle failure cases better (e.g. no sender ID) -func FromClientStateKey(roomID spec.RoomID, stateKey string, senderIDQuery spec.SenderIDForUser) (*string, error) { - if len(stateKey) >= 1 && stateKey[0] == '@' { - parsedStateKey, err := spec.NewUserID(stateKey, true) +func updateCreateEvent(content spec.RawJSON, userIDForSender spec.UserIDForSender, roomID spec.RoomID) (spec.RawJSON, error) { + if creator := gjson.GetBytes(content, "creator"); creator.Exists() { + oldCreator := creator.Str + userID, err := userIDForSender(roomID, spec.SenderID(oldCreator)) if err != nil { - // If invalid user ID, then there is no associated state event. - return nil, fmt.Errorf("Provided state key begins with @ but is not a valid user ID: %s", err.Error()) + err = fmt.Errorf("Failed to find userID for creator in ClientEvent: %w", err) + return nil, err } - senderID, err := senderIDQuery(roomID, *parsedStateKey) + + if userID != nil { + var newCreatorBytes, newContent []byte + newCreatorBytes, err = json.Marshal(userID.String()) + if err != nil { + err = fmt.Errorf("Failed to marshal new creator for ClientEvent: %w", err) + return nil, err + } + + newContent, err = sjson.SetRawBytes([]byte(content), "creator", newCreatorBytes) + if err != nil { + err = fmt.Errorf("Failed to set new creator for ClientEvent: %w", err) + return nil, err + } + + return newContent, nil + } + } + + return content, nil +} + +func updateInviteEvent(userIDForSender spec.UserIDForSender, ev gomatrixserverlib.PDU, eventFormat ClientEventFormat) (gomatrixserverlib.PDU, error) { + if inviteRoomState := gjson.GetBytes(ev.Unsigned(), "invite_room_state"); inviteRoomState.Exists() { + userID, err := userIDForSender(ev.RoomID(), ev.SenderID()) + if err != nil || userID == nil { + if err != nil { + err = fmt.Errorf("invalid userID found when updating invite_room_state: %w", err) + } + return nil, err + } + + newState, err := GetUpdatedInviteRoomState(userIDForSender, inviteRoomState, ev, ev.RoomID(), eventFormat) if err != nil { - return nil, fmt.Errorf("Failed to query sender ID: %s", err.Error()) + return nil, err } - if senderID == nil { - // If no sender ID, then there is no associated state event. - return nil, fmt.Errorf("No associated sender ID found.") + + var newEv []byte + newEv, err = sjson.SetRawBytes(ev.JSON(), "unsigned.invite_room_state", newState) + if err != nil { + return nil, err } - newStateKey := string(*senderID) - return &newStateKey, nil - } else { - return &stateKey, nil + + return gomatrixserverlib.MustGetRoomVersion(ev.Version()).NewEventFromTrustedJSON(newEv, false) + } + + return ev, nil +} + +type InviteRoomStateEvent struct { + Content spec.RawJSON `json:"content"` + SenderID string `json:"sender"` + StateKey *string `json:"state_key"` + Type string `json:"type"` +} + +func GetUpdatedInviteRoomState(userIDForSender spec.UserIDForSender, inviteRoomState gjson.Result, event gomatrixserverlib.PDU, roomID spec.RoomID, eventFormat ClientEventFormat) (spec.RawJSON, error) { + var res spec.RawJSON + inviteStateEvents := []InviteRoomStateEvent{} + err := json.Unmarshal([]byte(inviteRoomState.Raw), &inviteStateEvents) + if err != nil { + return nil, err } + + if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && eventFormat != FormatSyncFederation { + for i, ev := range inviteStateEvents { + userID, userIDErr := userIDForSender(roomID, spec.SenderID(ev.SenderID)) + if userIDErr != nil { + return nil, userIDErr + } + if userID != nil { + inviteStateEvents[i].SenderID = userID.String() + } + + if ev.StateKey != nil && *ev.StateKey != "" { + userID, senderErr := userIDForSender(roomID, spec.SenderID(*ev.StateKey)) + if senderErr != nil { + return nil, senderErr + } + if userID != nil { + user := userID.String() + inviteStateEvents[i].StateKey = &user + } + } + + updatedContent, updateErr := updateCreateEvent(ev.Content, userIDForSender, roomID) + if updateErr != nil { + updateErr = fmt.Errorf("Failed to update m.room.create event for ClientEvent: %w", userIDErr) + return nil, updateErr + } + inviteStateEvents[i].Content = updatedContent + } + } + + res, err = json.Marshal(inviteStateEvents) + if err != nil { + return nil, err + } + + return res, nil +} + +func updatePowerLevelEvent(userIDForSender spec.UserIDForSender, se gomatrixserverlib.PDU, eventFormat ClientEventFormat) (gomatrixserverlib.PDU, error) { + if !se.StateKeyEquals("") { + return se, nil + } + + pls, err := gomatrixserverlib.NewPowerLevelContentFromEvent(se) + if err != nil { + return nil, err + } + newPls := make(map[string]int64) + var userID *spec.UserID + for user, level := range pls.Users { + if eventFormat != FormatSyncFederation { + userID, err = userIDForSender(se.RoomID(), spec.SenderID(user)) + if err != nil { + return nil, err + } + user = userID.String() + } + newPls[user] = level + } + var newPlBytes, newEv []byte + newPlBytes, err = json.Marshal(newPls) + if err != nil { + return nil, err + } + newEv, err = sjson.SetRawBytes(se.JSON(), "content.users", newPlBytes) + if err != nil { + return nil, err + } + + // do the same for prev content + prevContent := gjson.GetBytes(se.JSON(), "unsigned.prev_content") + if !prevContent.Exists() { + var evNew gomatrixserverlib.PDU + evNew, err = gomatrixserverlib.MustGetRoomVersion(se.Version()).NewEventFromTrustedJSON(newEv, false) + if err != nil { + return nil, err + } + + return evNew, err + } + pls = gomatrixserverlib.PowerLevelContent{} + err = json.Unmarshal([]byte(prevContent.Raw), &pls) + if err != nil { + return nil, err + } + + newPls = make(map[string]int64) + for user, level := range pls.Users { + if eventFormat != FormatSyncFederation { + userID, err = userIDForSender(se.RoomID(), spec.SenderID(user)) + if err != nil { + return nil, err + } + user = userID.String() + } + newPls[user] = level + } + newPlBytes, err = json.Marshal(newPls) + if err != nil { + return nil, err + } + newEv, err = sjson.SetRawBytes(newEv, "unsigned.prev_content.users", newPlBytes) + if err != nil { + return nil, err + } + + var evNew gomatrixserverlib.PDU + evNew, err = gomatrixserverlib.MustGetRoomVersion(se.Version()).NewEventFromTrustedJSONWithEventID(se.EventID(), newEv, false) + if err != nil { + return nil, err + } + + return evNew, err } diff --git a/syncapi/synctypes/clientevent_test.go b/syncapi/synctypes/clientevent_test.go index 202c185f1b..662f9ea43b 100644 --- a/syncapi/synctypes/clientevent_test.go +++ b/syncapi/synctypes/clientevent_test.go @@ -26,6 +26,14 @@ import ( "github.com/matrix-org/gomatrixserverlib/spec" ) +func queryUserIDForSender(senderID spec.SenderID) (*spec.UserID, error) { + if senderID == "" { + return nil, nil + } + + return spec.NewUserID(string(senderID), true) +} + const testSenderID = "testSenderID" const testUserID = "@test:localhost" @@ -106,7 +114,12 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo t.Fatalf("failed to create userID: %s", err) } sk := "" - ce := ToClientEvent(ev, FormatAll, userID.String(), &sk, ev.Unsigned()) + ce, err := ToClientEvent(ev, FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return queryUserIDForSender(senderID) + }) + if err != nil { + t.Fatalf("failed to create ClientEvent: %s", err) + } verifyEventFields(t, EventFieldsToVerify{ @@ -161,12 +174,12 @@ func TestToClientFormatSync(t *testing.T) { if err != nil { t.Fatalf("failed to create Event: %s", err) } - userID, err := spec.NewUserID("@test:localhost", true) + ce, err := ToClientEvent(ev, FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return queryUserIDForSender(senderID) + }) if err != nil { - t.Fatalf("failed to create userID: %s", err) + t.Fatalf("failed to create ClientEvent: %s", err) } - sk := "" - ce := ToClientEvent(ev, FormatSync, userID.String(), &sk, ev.Unsigned()) if ce.RoomID != "" { t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID) } @@ -206,7 +219,12 @@ func TestToClientEventFormatSyncFederation(t *testing.T) { // nolint: gocyclo t.Fatalf("failed to create userID: %s", err) } sk := "" - ce := ToClientEvent(ev, FormatSyncFederation, userID.String(), &sk, ev.Unsigned()) + ce, err := ToClientEvent(ev, FormatSyncFederation, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return queryUserIDForSender(senderID) + }) + if err != nil { + t.Fatalf("failed to create ClientEvent: %s", err) + } verifyEventFields(t, EventFieldsToVerify{ diff --git a/syncapi/types/types.go b/syncapi/types/types.go index b90c128c34..bca11855c9 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -15,6 +15,7 @@ package types import ( + "context" "encoding/json" "errors" "fmt" @@ -532,7 +533,7 @@ type InviteResponse struct { } // NewInviteResponse creates an empty response with initialised arrays. -func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID, stateKey *string, eventFormat synctypes.ClientEventFormat) *InviteResponse { +func NewInviteResponse(ctx context.Context, rsAPI api.QuerySenderIDAPI, event *types.HeaderedEvent, eventFormat synctypes.ClientEventFormat) (*InviteResponse, error) { res := InviteResponse{} res.InviteState.Events = []json.RawMessage{} @@ -540,18 +541,42 @@ func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID, stateKey // If there is then unmarshal it into the response. This will contain the // partial room state such as join rules, room name etc. if inviteRoomState := gjson.GetBytes(event.Unsigned(), "invite_room_state"); inviteRoomState.Exists() { - _ = json.Unmarshal([]byte(inviteRoomState.Raw), &res.InviteState.Events) + if event.Version() == gomatrixserverlib.RoomVersionPseudoIDs && eventFormat != synctypes.FormatSyncFederation { + updatedInvite, err := synctypes.GetUpdatedInviteRoomState(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }, inviteRoomState, event.PDU, event.RoomID(), eventFormat) + if err != nil { + return nil, err + } + _ = json.Unmarshal(updatedInvite, &res.InviteState.Events) + } else { + _ = json.Unmarshal([]byte(inviteRoomState.Raw), &res.InviteState.Events) + } + } + + // Clear unsigned so it doesn't have pseudoIDs converted during ToClientEvent + eventNoUnsigned, err := event.SetUnsigned(nil) + if err != nil { + return nil, err } // Then we'll see if we can create a partial of the invite event itself. // This is needed for clients to work out *who* sent the invite. - inviteEvent := synctypes.ToClientEvent(event.PDU, eventFormat, userID.String(), stateKey, event.Unsigned()) + inviteEvent, err := synctypes.ToClientEvent(eventNoUnsigned, eventFormat, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) + if err != nil { + return nil, err + } + + // Ensure unsigned field is empty so it isn't marshalled into the final JSON inviteEvent.Unsigned = nil - if ev, err := json.Marshal(inviteEvent); err == nil { + + if ev, err := json.Marshal(*inviteEvent); err == nil { res.InviteState.Events = append(res.InviteState.Events, ev) } - return &res + return &res, nil } // LeaveResponse represents a /sync response for a room which is under the 'leave' key. diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go index a79b9fc5d8..35e1882cba 100644 --- a/syncapi/types/types_test.go +++ b/syncapi/types/types_test.go @@ -1,6 +1,7 @@ package types import ( + "context" "encoding/json" "reflect" "testing" @@ -11,8 +12,19 @@ import ( "github.com/matrix-org/gomatrixserverlib/spec" ) -func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) { - return spec.NewUserID(senderID, true) +type FakeRoomserverAPI struct{} + +func (f *FakeRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + if senderID == "" { + return nil, nil + } + + return spec.NewUserID(string(senderID), true) +} + +func (f *FakeRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (*spec.SenderID, error) { + sender := spec.SenderID(userID.String()) + return &sender, nil } func TestSyncTokens(t *testing.T) { @@ -61,25 +73,18 @@ func TestNewInviteResponse(t *testing.T) { t.Fatal(err) } - sender, err := spec.NewUserID("@neilalexander:matrix.org", true) + rsAPI := FakeRoomserverAPI{} + res, err := NewInviteResponse(context.Background(), &rsAPI, &types.HeaderedEvent{PDU: ev}, synctypes.FormatSync) if err != nil { t.Fatal(err) } - skUserID, err := spec.NewUserID("@neilalexander:dendrite.neilalexander.dev", true) - if err != nil { - t.Fatal(err) - } - skString := skUserID.String() - sk := &skString - - res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender, sk, synctypes.FormatSync) j, err := json.Marshal(res) if err != nil { t.Fatal(err) } if string(j) != expected { - t.Fatalf("Invite response didn't contain correct info") + t.Fatalf("Invite response didn't contain correct info, \nexpected: %s \ngot: %s", expected, string(j)) } } diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index dca09193c4..047fe9216b 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -301,25 +301,14 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst switch { case event.Type() == spec.MRoomMember: - sender := spec.UserID{} - userID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) - if queryErr == nil && userID != nil { - sender = *userID - } - - sk := event.StateKey() - if sk != nil && *sk != "" { - skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*sk)) - if queryErr == nil && skUserID != nil { - skString := skUserID.String() - sk = &skString - } else { - return fmt.Errorf("queryUserIDForSender: userID unknown for %s", *sk) - } + cevent, clientEvErr := synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) + if clientEvErr != nil { + return clientEvErr } - cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, sender.String(), sk, event.Unsigned()) var member *localMembership - member, err = newLocalMembership(&cevent) + member, err = newLocalMembership(cevent) if err != nil { return fmt.Errorf("newLocalMembership: %w", err) } @@ -538,27 +527,19 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype if err != nil { return fmt.Errorf("s.localPushDevices: %w", err) } - - sender := spec.UserID{} - userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) - if err == nil && userID != nil { - sender = *userID + clientEvent, err := synctypes.ToClientEvent(event, synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) + if err != nil { + return err } - sk := event.StateKey() - if sk != nil && *sk != "" { - skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey())) - if queryErr == nil && skUserID != nil { - skString := skUserID.String() - sk = &skString - } - } n := &api.Notification{ Actions: actions, // UNSPEC: the spec doesn't say this is a ClientEvent, but the // fields seem to match. room_id should be missing, which // matches the behaviour of FormatSync. - Event: synctypes.ToClientEvent(event, synctypes.FormatSync, sender.String(), sk, event.Unsigned()), + Event: *clientEvent, // TODO: this is per-device, but it's not part of the primary // key. So inserting one notification per profile tag doesn't // make sense. What is this supposed to be? Sytests require it diff --git a/userapi/util/notify_test.go b/userapi/util/notify_test.go index 27e77cf7aa..2ea978d692 100644 --- a/userapi/util/notify_test.go +++ b/userapi/util/notify_test.go @@ -23,6 +23,14 @@ import ( userUtil "github.com/matrix-org/dendrite/userapi/util" ) +func queryUserIDForSender(senderID spec.SenderID) (*spec.UserID, error) { + if senderID == "" { + return nil, nil + } + + return spec.NewUserID(string(senderID), true) +} + func TestNotifyUserCountsAsync(t *testing.T) { alice := test.NewUser(t) aliceLocalpart, serverName, err := gomatrixserverlib.SplitID('@', alice.ID) @@ -100,13 +108,14 @@ func TestNotifyUserCountsAsync(t *testing.T) { } // Insert a dummy event - sender, err := spec.NewUserID(alice.ID, true) + ev, err := synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return queryUserIDForSender(senderID) + }) if err != nil { t.Error(err) } - sk := "" if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{ - Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, sender.String(), &sk, dummyEvent.Unsigned()), + Event: *ev, }); err != nil { t.Error(err) }