Skip to content

Commit

Permalink
add basic support for inbox forwarding
Browse files Browse the repository at this point in the history
  • Loading branch information
dimkr committed Oct 1, 2023
1 parent faf9567 commit c5e4943
Show file tree
Hide file tree
Showing 11 changed files with 709 additions and 49 deletions.
4 changes: 4 additions & 0 deletions ap/activity.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ type Activity struct {

var ErrInvalidActivity = errors.New("Invalid activity")

func (a *Activity) IsPublic() bool {
return a.To.Contains(Public) || a.CC.Contains(Public)
}

func (a *Activity) UnmarshalJSON(b []byte) error {
var common anyActivity
if err := json.Unmarshal(b, &common); err != nil {
Expand Down
63 changes: 35 additions & 28 deletions fed/deliver.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func ProcessQueue(ctx context.Context, log *slog.Logger, db *sql.DB, resolver *R
func processQueue(ctx context.Context, log *slog.Logger, db *sql.DB, resolver *Resolver) error {
log.Debug("Polling delivery queue")

rows, err := db.QueryContext(ctx, `select outbox.attempts, outbox.activity, persons.actor from outbox join persons on persons.id = outbox.activity->>'actor' where outbox.sent = 0 and (outbox.attempts = 0 or (outbox.attempts < ? and outbox.last <= unixepoch() - ?)) order by outbox.attempts asc, outbox.last asc limit ?`, MaxDeliveryAttempts, deliveryRetryInterval, batchSize)
rows, err := db.QueryContext(ctx, `select outbox.attempts, outbox.activity, persons.actor from outbox join persons on persons.id = outbox.sender where outbox.sent = 0 and (outbox.attempts = 0 or (outbox.attempts < ? and outbox.last <= unixepoch() - ?)) order by outbox.attempts asc, outbox.last asc limit ?`, MaxDeliveryAttempts, deliveryRetryInterval, batchSize)
if err != nil {
return fmt.Errorf("Failed to fetch posts to deliver: %w", err)
}
Expand Down Expand Up @@ -89,7 +89,7 @@ func processQueue(ctx context.Context, log *slog.Logger, db *sql.DB, resolver *R
continue
}

if err := deliverWithTimeout(ctx, log, db, resolver, &activity, &actor); err != nil {
if err := deliverWithTimeout(ctx, log, db, resolver, &activity, []byte(activityString), &actor); err != nil {
log.Warn("Failed to deliver activity", "id", activity.ID, "attempts", deliveryAttempts, "error", err)
continue
}
Expand All @@ -105,43 +105,42 @@ func processQueue(ctx context.Context, log *slog.Logger, db *sql.DB, resolver *R
return nil
}

func deliverWithTimeout(parent context.Context, log *slog.Logger, db *sql.DB, resolver *Resolver, activity *ap.Activity, actor *ap.Actor) error {
func deliverWithTimeout(parent context.Context, log *slog.Logger, db *sql.DB, resolver *Resolver, activity *ap.Activity, rawActivity []byte, actor *ap.Actor) error {
ctx, cancel := context.WithTimeout(parent, deliveryTimeout)
defer cancel()
return deliver(ctx, log, db, activity, actor, resolver)
return deliver(ctx, log, db, activity, rawActivity, actor, resolver)
}

func deliver(ctx context.Context, log *slog.Logger, db *sql.DB, activity *ap.Activity, actor *ap.Actor, resolver *Resolver) error {
buf, err := json.Marshal(activity)
if err != nil {
return fmt.Errorf("Failed to marshal activity: %w", err)
}
func deliver(ctx context.Context, log *slog.Logger, db *sql.DB, activity *ap.Activity, rawActivity []byte, actor *ap.Actor, resolver *Resolver) error {
isForwarded := activity.Actor != actor.ID

// deduplicate recipients
recipients := data.OrderedMap[string, struct{}]{}

activity.To.Range(func(id string, _ struct{}) bool {
recipients.Store(id, struct{}{})
return true
})
if !isForwarded {
activity.To.Range(func(id string, _ struct{}) bool {
recipients.Store(id, struct{}{})
return true
})

activity.CC.Range(func(id string, _ struct{}) bool {
recipients.Store(id, struct{}{})
return true
})
activity.CC.Range(func(id string, _ struct{}) bool {
recipients.Store(id, struct{}{})
return true
})
}

actorIDs := data.OrderedMap[string, struct{}]{}

// list the author's federated followers
if obj, ok := activity.Object.(*ap.Object); ok && obj.Type == ap.NoteObject && (obj.IsPublic() || recipients.Contains(actor.Followers)) {
// list the actor's federated followers if we're forwarding an activity by another actor, or if addressed by actor
if isForwarded || (activity.Actor == actor.ID && (activity.IsPublic() || recipients.Contains(actor.Followers))) {
followers, err := db.QueryContext(ctx, `select distinct follower from follows where followed = ? and follower not like ? and accepted = 1`, actor.ID, fmt.Sprintf("https://%s/%%", cfg.Domain))
if err != nil {
log.Warn("Failed to list followers", "post", obj.ID, "error", err)
log.Warn("Failed to list followers", "activity", activity.ID, "error", err)
} else {
for followers.Next() {
var follower string
if err := followers.Scan(&follower); err != nil {
log.Warn("Skipped a follower", "post", obj.ID, "error", err)
log.Warn("Skipped a follower", "activity", activity.ID, "error", err)
continue
}

Expand All @@ -153,27 +152,35 @@ func deliver(ctx context.Context, log *slog.Logger, db *sql.DB, activity *ap.Act
}

// assume that all other federated recipients are actors and not collections
prefix := fmt.Sprintf("https://%s/", cfg.Domain)
recipients.Range(func(recipient string, _ struct{}) bool {
if recipient != ap.Public && !strings.HasPrefix(recipient, prefix) {
actorIDs.Store(recipient, struct{}{})
}

actorIDs.Store(recipient, struct{}{})
return true
})

anyFailed := false

var author string
if obj, ok := activity.Object.(*ap.Object); ok {
author = obj.AttributedTo
}

prefix := fmt.Sprintf("https://%s/", cfg.Domain)

actorIDs.Range(func(actorID string, _ struct{}) bool {
if actorID == author || actorID == ap.Public || strings.HasPrefix(actorID, prefix) {
log.Debug("Skipping recipient", "to", actorID, "activity", activity.ID)
return true
}

log.Info("Delivering activity to recipient", "to", actorID, "activity", activity.ID)

if to, err := resolver.Resolve(ctx, log, db, actor, actorID, false); err != nil {
log.Warn("Failed to resolve a recipient", "to", actorID, "activity", activity.ID, "error", err)
if !errors.Is(err, ErrActorGone) && !errors.Is(err, ErrBlockedDomain) {
anyFailed = true
}
} else if err := Send(ctx, log, db, actor, resolver, to, buf); err != nil {
log.Warn("Failed to send a post", "to", actorID, "activity", activity.ID, "error", err)
} else if err := Send(ctx, log, db, actor, resolver, to, rawActivity); err != nil {
log.Warn("Failed to send an activity", "to", actorID, "activity", activity.ID, "error", err)
if !errors.Is(err, ErrBlockedDomain) {
anyFailed = true
}
Expand Down
99 changes: 84 additions & 15 deletions inbox/queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,58 @@ const (
activitiesPollingInterval = time.Second * 5
activitiesBatchDelay = time.Millisecond * 100
activityProcessingTimeout = time.Second * 15
maxForwardingDepth = 5
)

func processCreateActivity(ctx context.Context, log *slog.Logger, sender *ap.Actor, req *ap.Activity, post *ap.Object, db *sql.DB, resolver *fed.Resolver, from *ap.Actor) error {
// a reply by B in a thread started by A is forwarded to all followers of A
func forwardActivity(ctx context.Context, log *slog.Logger, tx *sql.Tx, activity *ap.Activity, rawActivity []byte) error {
obj, ok := activity.Object.(*ap.Object)
if !ok {
return nil
}

var firstPostID, threadStarterID string
var depth int
if err := tx.QueryRowContext(ctx, `with recursive thread(id, author, parent, depth) as (select notes.id, notes.author, notes.object->>'inReplyTo' as parent, 1 as depth from notes where id = ? union select notes.id, notes.author, notes.object->>'inReplyTo' as parent, t.depth + 1 from thread t join notes on notes.id = t.parent) select id, author, depth from thread order by depth desc limit 1`, obj.ID).Scan(&firstPostID, &threadStarterID, &depth); err != nil && errors.Is(err, sql.ErrNoRows) {
log.Debug("Failed to find thread for post", "post", obj.ID)
return nil
} else if err != nil {
return fmt.Errorf("failed to fetch first post in thread: %w", err)
}
if depth > maxForwardingDepth {
log.Debug("Thread exceeds depth limit for forwarding")
return nil
}

prefix := fmt.Sprintf("https://%s/", cfg.Domain)
if !strings.HasPrefix(threadStarterID, prefix) {
log.Debug("Thread starter is federated")
return nil
}

var shouldForward int
if err := tx.QueryRowContext(ctx, `select exists (select 1 from notes join persons on persons.id = notes.author and (notes.public = 1 or exists (select 1 from json_each(notes.object->'to') where value = persons.actor->>'followers') or exists (select 1 from json_each(notes.object->'cc') where value = persons.actor->>'followers')) where notes.id = ?)`, firstPostID).Scan(&shouldForward); err != nil {
return err
}
if shouldForward == 0 {
log.Debug("Activity does not need to be forwarded")
return nil
}

if _, err := tx.ExecContext(
ctx,
`INSERT INTO outbox (activity, sender) VALUES(?,?)`,
string(rawActivity),
threadStarterID,
); err != nil {
return err
}

log.Info("Forwarding activity to followers of thread starter", "thread", firstPostID, "starter", threadStarterID)
return nil
}

func processCreateActivity(ctx context.Context, log *slog.Logger, sender *ap.Actor, req *ap.Activity, rawActivity []byte, post *ap.Object, db *sql.DB, resolver *fed.Resolver, from *ap.Actor) error {
prefix := fmt.Sprintf("https://%s/", cfg.Domain)
if strings.HasPrefix(sender.ID, prefix) || strings.HasPrefix(post.ID, prefix) || strings.HasPrefix(post.AttributedTo, prefix) || strings.HasPrefix(req.Actor, prefix) {
return fmt.Errorf("Received invalid Create for %s by %s from %s", post.ID, post.AttributedTo, req.Actor)
Expand All @@ -68,6 +117,11 @@ func processCreateActivity(ctx context.Context, log *slog.Logger, sender *ap.Act
if err := note.Insert(ctx, log, tx, post); err != nil {
return fmt.Errorf("Cannot insert %s: %w", post.ID, err)
}

if err := forwardActivity(ctx, log, tx, req, rawActivity); err != nil {
return fmt.Errorf("Cannot forward %s: %w", post.ID, err)
}

if err := tx.Commit(); err != nil {
return fmt.Errorf("Cannot insert %s: %w", post.ID, err)
}
Expand All @@ -93,7 +147,7 @@ func processCreateActivity(ctx context.Context, log *slog.Logger, sender *ap.Act
return nil
}

func processActivity(ctx context.Context, log *slog.Logger, sender *ap.Actor, req *ap.Activity, db *sql.DB, resolver *fed.Resolver, from *ap.Actor) error {
func processActivity(ctx context.Context, log *slog.Logger, sender *ap.Actor, req *ap.Activity, rawActivity []byte, db *sql.DB, resolver *fed.Resolver, from *ap.Actor) error {
log.Debug("Processing activity")

switch req.Type {
Expand Down Expand Up @@ -205,7 +259,7 @@ func processActivity(ctx context.Context, log *slog.Logger, sender *ap.Actor, re
return errors.New("Received invalid Create")
}

return processCreateActivity(ctx, log, sender, req, post, db, resolver, from)
return processCreateActivity(ctx, log, sender, req, rawActivity, post, db, resolver, from)

case ap.AnnounceActivity:
create, ok := req.Object.(*ap.Activity)
Expand All @@ -228,22 +282,23 @@ func processActivity(ctx context.Context, log *slog.Logger, sender *ap.Actor, re
return errors.New("Sender is not post author or recipient")
}

return processCreateActivity(ctx, log, sender, create, post, db, resolver, from)
return processCreateActivity(ctx, log, sender, create, rawActivity, post, db, resolver, from)

case ap.UpdateActivity:
post, ok := req.Object.(*ap.Object)
if !ok || post.ID == "" || post.AttributedTo == "" {
return errors.New("Received invalid Update")
}

if sender.ID != post.AttributedTo {
prefix := fmt.Sprintf("https://%s/", cfg.Domain)
if strings.HasPrefix(post.ID, prefix) {
return fmt.Errorf("%s cannot update posts by %s", sender.ID, post.AttributedTo)
}

var lastUpdate sql.NullInt64
if err := db.QueryRowContext(ctx, `select max(inserted, updated) from notes where id = ? and author = ?`, post.ID, post.AttributedTo).Scan(&lastUpdate); err != nil && errors.Is(err, sql.ErrNoRows) {
log.Debug("Received Update for non-existing post")
return processCreateActivity(ctx, log, sender, req, post, db, resolver, from)
return processCreateActivity(ctx, log, sender, req, rawActivity, post, db, resolver, from)
} else if err != nil {
return fmt.Errorf("Failed to get last update time for %s: %w", post.ID, err)
}
Expand All @@ -258,7 +313,13 @@ func processActivity(ctx context.Context, log *slog.Logger, sender *ap.Actor, re
return fmt.Errorf("Failed to update post %s: %w", post.ID, err)
}

if _, err := db.ExecContext(
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("Cannot insert %s: %w", post.ID, err)
}
defer tx.Rollback()

if _, err := tx.ExecContext(
ctx,
`update notes set object = ?, updated = unixepoch() where id = ?`,
string(body),
Expand All @@ -267,6 +328,14 @@ func processActivity(ctx context.Context, log *slog.Logger, sender *ap.Actor, re
return fmt.Errorf("Failed to update post %s: %w", post.ID, err)
}

if err := forwardActivity(ctx, log, tx, req, rawActivity); err != nil {
return fmt.Errorf("Failed to forward update pos %s: %w", post.ID, err)
}

if err := tx.Commit(); err != nil {
return fmt.Errorf("Failed to update post %s: %w", post.ID, err)
}

log.Info("Updated post")

case ap.LikeActivity:
Expand All @@ -283,24 +352,24 @@ func processActivity(ctx context.Context, log *slog.Logger, sender *ap.Actor, re
return nil
}

func processActivityWithTimeout(parent context.Context, log *slog.Logger, sender *ap.Actor, activity *ap.Activity, db *sql.DB, resolver *fed.Resolver, from *ap.Actor) {
func processActivityWithTimeout(parent context.Context, log *slog.Logger, sender *ap.Actor, activity *ap.Activity, rawActivity []byte, db *sql.DB, resolver *fed.Resolver, from *ap.Actor) {
ctx, cancel := context.WithTimeout(parent, activityProcessingTimeout)
defer cancel()

if o, ok := activity.Object.(*ap.Object); ok {
log = log.With(slog.Group("activity", "sender", sender.ID, "type", activity.Type, "actor", activity.Actor, slog.Group("object", "kind", "object", "id", o.ID, "type", o.Type, "attributed_to", o.AttributedTo)))
log = log.With(slog.Group("activity", "id", activity.ID, "sender", sender.ID, "type", activity.Type, "actor", activity.Actor, slog.Group("object", "kind", "object", "id", o.ID, "type", o.Type, "attributed_to", o.AttributedTo)))
} else if a, ok := activity.Object.(*ap.Activity); ok {
log = log.With(slog.Group("activity", "sender", sender.ID, "type", activity.Type, "actor", activity.Actor, slog.Group("object", "kind", "activity", "id", a.ID, "type", a.Type, "actor", a.Actor)))
log = log.With(slog.Group("activity", "id", activity.ID, "sender", sender.ID, "type", activity.Type, "actor", activity.Actor, slog.Group("object", "kind", "activity", "id", a.ID, "type", a.Type, "actor", a.Actor)))
} else if s, ok := activity.Object.(string); ok {
log = log.With(slog.Group("activity", "sender", sender.ID, "type", activity.Type, "actor", activity.Actor, slog.Group("object", "kind", "string", "id", s)))
log = log.With(slog.Group("activity", "id", activity.ID, "sender", sender.ID, "type", activity.Type, "actor", activity.Actor, slog.Group("object", "kind", "string", "id", s)))
}

if err := processActivity(ctx, log, sender, activity, db, resolver, from); err != nil {
if err := processActivity(ctx, log, sender, activity, rawActivity, db, resolver, from); err != nil {
log.Warn("Failed to process activity", "error", err)
}
}

func processBatch(ctx context.Context, log *slog.Logger, db *sql.DB, resolver *fed.Resolver, from *ap.Actor) (int, error) {
func ProcessBatch(ctx context.Context, log *slog.Logger, db *sql.DB, resolver *fed.Resolver, from *ap.Actor) (int, error) {
log.Debug("Polling activities queue")

rows, err := db.QueryContext(ctx, `select inbox.id, persons.actor, inbox.activity from (select * from inbox limit -1 offset case when (select count(*) from inbox) >= $1 then $1/10 else 0 end) inbox left join persons on persons.id = inbox.sender order by inbox.id limit $2`, maxActivitiesQueueSize, activitiesBatchSize)
Expand Down Expand Up @@ -352,7 +421,7 @@ func processBatch(ctx context.Context, log *slog.Logger, db *sql.DB, resolver *f
return true
}

processActivityWithTimeout(ctx, log, &sender, &activity, db, resolver, from)
processActivityWithTimeout(ctx, log, &sender, &activity, []byte(activityString), db, resolver, from)
return true
})

Expand All @@ -368,7 +437,7 @@ func processQueue(ctx context.Context, log *slog.Logger, db *sql.DB, resolver *f
defer t.Stop()

for {
n, err := processBatch(ctx, log, db, resolver, from)
n, err := ProcessBatch(ctx, log, db, resolver, from)
if err != nil {
return err
}
Expand Down
18 changes: 18 additions & 0 deletions migrations/007_outboxsender.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package migrations

import (
"context"
"database/sql"
)

func outboxsender(ctx context.Context, tx *sql.Tx) error {
if _, err := tx.ExecContext(ctx, `ALTER TABLE outbox ADD COLUMN sender STRING`); err != nil {
return err
}

if _, err := tx.ExecContext(ctx, `UPDATE outbox SET sender = activity->>'actor'`); err != nil {
return err
}

return nil
}
3 changes: 2 additions & 1 deletion outbox/accept.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,9 @@ func Accept(ctx context.Context, followed, follower, followID string, db *sql.DB

if _, err := tx.ExecContext(
ctx,
`INSERT INTO outbox (activity) VALUES(?)`,
`INSERT INTO outbox (activity, sender) VALUES(?,?)`,
string(accept),
followed,
); err != nil {
return fmt.Errorf("Failed to insert Accept: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion outbox/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func Create(ctx context.Context, log *slog.Logger, db *sql.DB, post *ap.Object,
return fmt.Errorf("Failed to insert note: %w", err)
}

if _, err = tx.ExecContext(ctx, `insert into outbox (activity) values(?)`, string(create)); err != nil {
if _, err = tx.ExecContext(ctx, `insert into outbox (activity, sender) values(?,?)`, string(create), author.ID); err != nil {
return fmt.Errorf("Failed to insert Create: %w", err)
}

Expand Down
3 changes: 2 additions & 1 deletion outbox/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,9 @@ func Delete(ctx context.Context, db *sql.DB, note *ap.Object) error {

if _, err := tx.ExecContext(
ctx,
`INSERT INTO outbox (activity) VALUES (?)`,
`INSERT INTO outbox (activity, sender) VALUES (?,?)`,
string(delete),
note.AttributedTo,
); err != nil {
return fmt.Errorf("Failed to insert delete activity: %w", err)
}
Expand Down
3 changes: 2 additions & 1 deletion outbox/edit.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@ func Edit(ctx context.Context, db *sql.DB, note *ap.Object, newContent string) e

if _, err := tx.ExecContext(
ctx,
`INSERT INTO outbox (activity) VALUES(?)`,
`INSERT INTO outbox (activity, sender) VALUES(?,?)`,
string(update),
note.AttributedTo,
); err != nil {
return fmt.Errorf("Failed to insert update activity: %w", err)
}
Expand Down
Loading

0 comments on commit c5e4943

Please sign in to comment.