Skip to content
This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

Commit

Permalink
Move pseudoID ClientEvent hotswapping to a common location (#3199)
Browse files Browse the repository at this point in the history
Fixes a variety of issues where clients were receiving pseudoIDs in
places that should be userIDs.
This change makes pseudoIDs work with sliding sync & element x.

---------

Co-authored-by: Till <[email protected]>
  • Loading branch information
devonh and S7evinK authored Sep 15, 2023
1 parent 8245b24 commit db83789
Show file tree
Hide file tree
Showing 13 changed files with 421 additions and 328 deletions.
21 changes: 7 additions & 14 deletions clientapi/routing/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,23 +172,16 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
}
}
for _, ev := range stateAfterRes.StateEvents {
sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID())
if err == nil && userID != nil {
sender = *userID
}

sk := ev.StateKey()
if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
clientEvent, err := synctypes.ToClientEvent(ev, synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
if err != nil {
util.GetLogger(ctx).WithError(err).Error("Failed converting to ClientEvent")
continue
}
stateEvents = append(
stateEvents,
synctypes.ToClientEvent(ev, synctypes.FormatAll, sender.String(), sk, ev.Unsigned()),
*clientEvent,
)
}
}
Expand Down
6 changes: 4 additions & 2 deletions internal/eventutil/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,13 @@ func RedactEvent(ctx context.Context, redactionEvent, redactedEvent gomatrixserv
return fmt.Errorf("RedactEvent: redactionEvent isn't a redaction event, is '%s'", redactionEvent.Type())
}
redactedEvent.Redact()
senderID, err := querier.QueryUserIDForSender(ctx, redactedEvent.RoomID(), redactionEvent.SenderID())
clientEvent, err := synctypes.ToClientEvent(redactionEvent, synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return querier.QueryUserIDForSender(ctx, roomID, senderID)
})
if err != nil {
return err
}
redactedBecause := synctypes.ToClientEvent(redactionEvent, synctypes.FormatSync, senderID.String(), redactionEvent.StateKey(), redactionEvent.Unsigned())
redactedBecause := clientEvent
if err := redactedEvent.SetUnsignedField("redacted_because", redactedBecause); err != nil {
return err
}
Expand Down
18 changes: 6 additions & 12 deletions syncapi/routing/getevent.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,25 +118,19 @@ func GetEvent(
}
}

senderUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *roomID, events[0].SenderID())
if err != nil || senderUserID == nil {
util.GetLogger(req.Context()).WithError(err).WithField("senderID", events[0].SenderID()).WithField("roomID", *roomID).Error("QueryUserIDForSender errored or returned nil-user ID when user should be part of a room")
clientEvent, err := synctypes.ToClientEvent(events[0], synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
if err != nil {
util.GetLogger(req.Context()).WithError(err).WithField("senderID", events[0].SenderID()).WithField("roomID", *roomID).Error("Failed converting to ClientEvent")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.Unknown("internal server error"),
}
}

sk := events[0].StateKey()
if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(ctx, events[0].RoomID(), spec.SenderID(*events[0].StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, senderUserID.String(), sk, events[0].Unsigned()),
JSON: *clientEvent,
}
}
21 changes: 7 additions & 14 deletions syncapi/routing/relations.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,23 +130,16 @@ func Relations(
// type if it was specified.
res.Chunk = make([]synctypes.ClientEvent, 0, len(filteredEvents))
for _, event := range filteredEvents {
sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(req.Context(), *roomID, event.SenderID())
if err == nil && userID != nil {
sender = *userID
}

sk := event.StateKey()
if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *roomID, spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
clientEvent, err := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
})
if err != nil {
util.GetLogger(req.Context()).WithError(err).WithField("senderID", events[0].SenderID()).WithField("roomID", *roomID).Error("Failed converting to ClientEvent")
continue
}
res.Chunk = append(
res.Chunk,
synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender.String(), sk, event.Unsigned()),
*clientEvent,
)
}

Expand Down
20 changes: 7 additions & 13 deletions syncapi/routing/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,20 +230,14 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
profileInfos[userID.String()] = profile
}

sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), event.SenderID())
if err == nil && userID != nil {
sender = *userID
clientEvent, err := synctypes.ToClientEvent(event, synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
if err != nil {
util.GetLogger(req.Context()).WithError(err).WithField("senderID", event.SenderID()).Error("Failed converting to ClientEvent")
continue
}

sk := event.StateKey()
if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
results = append(results, Result{
Context: SearchContextResponse{
Start: startToken.String(),
Expand All @@ -257,7 +251,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
ProfileInfo: profileInfos,
},
Rank: eventScore[event.EventID()].Score,
Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender.String(), sk, event.Unsigned()),
Result: *clientEvent,
})
roomGroup := groups[event.RoomID().String()]
roomGroup.Results = append(roomGroup.Results, event.EventID())
Expand Down
15 changes: 5 additions & 10 deletions syncapi/streams/stream_invite.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,20 +75,15 @@ func (p *InviteStreamProvider) IncrementalSync(
user = *sender
}

sk := inviteEvent.StateKey()
if sk != nil && *sk != "" {
skUserID, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}

// skip ignored user events
if _, ok := req.IgnoredUsers.List[user.String()]; ok {
continue
}
ir := types.NewInviteResponse(inviteEvent, user, sk, eventFormat)
ir, err := types.NewInviteResponse(ctx, p.rsAPI, inviteEvent, eventFormat)
if err != nil {
req.Log.WithError(err).Error("failed creating invite response")
continue
}
req.Response.Rooms.Invite[roomID] = ir
}

Expand Down
131 changes: 0 additions & 131 deletions syncapi/streams/stream_pdu.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package streams
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"time"

Expand All @@ -16,8 +15,6 @@ import (
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"

"github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/gomatrixserverlib"
Expand Down Expand Up @@ -359,23 +356,6 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
// Now that we've filtered the timeline, work out which state events are still
// left. Anything that appears in the filtered timeline will be removed from the
// "state" section and kept in "timeline".

// update the powerlevel event for timeline events
for i, ev := range events {
if ev.Version() != gomatrixserverlib.RoomVersionPseudoIDs {
continue
}
if ev.Type() != spec.MRoomPowerLevels || !ev.StateKeyEquals("") {
continue
}
var newEvent gomatrixserverlib.PDU
newEvent, err = p.updatePowerLevelEvent(ctx, ev, eventFormat)
if err != nil {
return r.From, err
}
events[i] = &rstypes.HeaderedEvent{PDU: newEvent}
}

sEvents := gomatrixserverlib.HeaderedReverseTopologicalOrdering(
gomatrixserverlib.ToPDUs(removeDuplicates(delta.StateEvents, events)),
gomatrixserverlib.TopologicalOrderByAuthEvents,
Expand All @@ -390,15 +370,6 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
continue
}
delta.StateEvents[i-skipped] = he
// update the powerlevel event for state events
if ev.Version() == gomatrixserverlib.RoomVersionPseudoIDs && ev.Type() == spec.MRoomPowerLevels && ev.StateKeyEquals("") {
var newEvent gomatrixserverlib.PDU
newEvent, err = p.updatePowerLevelEvent(ctx, he, eventFormat)
if err != nil {
return r.From, err
}
delta.StateEvents[i-skipped] = &rstypes.HeaderedEvent{PDU: newEvent}
}
}
delta.StateEvents = delta.StateEvents[:len(sEvents)-skipped]

Expand Down Expand Up @@ -468,79 +439,6 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
return latestPosition, nil
}

func (p *PDUStreamProvider) updatePowerLevelEvent(ctx context.Context, ev *rstypes.HeaderedEvent, eventFormat synctypes.ClientEventFormat) (gomatrixserverlib.PDU, error) {
pls, err := gomatrixserverlib.NewPowerLevelContentFromEvent(ev)
if err != nil {
return nil, err
}
newPls := make(map[string]int64)
var userID *spec.UserID
for user, level := range pls.Users {
if eventFormat != synctypes.FormatSyncFederation {
userID, err = p.rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(user))
if err != nil {
return nil, err
}
user = userID.String()
}
newPls[user] = level
}
var newPlBytes, newEv []byte
newPlBytes, err = json.Marshal(newPls)
if err != nil {
return nil, err
}
newEv, err = sjson.SetRawBytes(ev.JSON(), "content.users", newPlBytes)
if err != nil {
return nil, err
}

// do the same for prev content
prevContent := gjson.GetBytes(ev.JSON(), "unsigned.prev_content")
if !prevContent.Exists() {
var evNew gomatrixserverlib.PDU
evNew, err = gomatrixserverlib.MustGetRoomVersion(ev.Version()).NewEventFromTrustedJSONWithEventID(ev.EventID(), newEv, false)
if err != nil {
return nil, err
}

return evNew, err
}
pls = gomatrixserverlib.PowerLevelContent{}
err = json.Unmarshal([]byte(prevContent.Raw), &pls)
if err != nil {
return nil, err
}

newPls = make(map[string]int64)
for user, level := range pls.Users {
if eventFormat != synctypes.FormatSyncFederation {
userID, err = p.rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(user))
if err != nil {
return nil, err
}
user = userID.String()
}
newPls[user] = level
}
newPlBytes, err = json.Marshal(newPls)
if err != nil {
return nil, err
}
newEv, err = sjson.SetRawBytes(newEv, "unsigned.prev_content.users", newPlBytes)
if err != nil {
return nil, err
}

var evNew gomatrixserverlib.PDU
evNew, err = gomatrixserverlib.MustGetRoomVersion(ev.Version()).NewEventFromTrustedJSONWithEventID(ev.EventID(), newEv, false)
if err != nil {
return nil, err
}

return evNew, err
}

// applyHistoryVisibilityFilter gets the current room state and supplies it to ApplyHistoryVisibilityFilter, to make
// sure we always return the required events in the timeline.
func applyHistoryVisibilityFilter(
Expand Down Expand Up @@ -690,35 +588,6 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
prevBatch.Decrement()
}

// Update powerlevel events for timeline events
for i, ev := range events {
if ev.Version() != gomatrixserverlib.RoomVersionPseudoIDs {
continue
}
if ev.Type() != spec.MRoomPowerLevels || !ev.StateKeyEquals("") {
continue
}
newEvent, err := p.updatePowerLevelEvent(ctx, ev, eventFormat)
if err != nil {
return nil, err
}
events[i] = &rstypes.HeaderedEvent{PDU: newEvent}
}
// Update powerlevel events for state events
for i, ev := range stateEvents {
if ev.Version() != gomatrixserverlib.RoomVersionPseudoIDs {
continue
}
if ev.Type() != spec.MRoomPowerLevels || !ev.StateKeyEquals("") {
continue
}
newEvent, err := p.updatePowerLevelEvent(ctx, ev, eventFormat)
if err != nil {
return nil, err
}
stateEvents[i] = &rstypes.HeaderedEvent{PDU: newEvent}
}

jr.Timeline.PrevBatch = prevBatch
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), eventFormat, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
Expand Down
Loading

0 comments on commit db83789

Please sign in to comment.