diff --git a/ap/activity.go b/ap/activity.go index 624d386f..1c13028a 100644 --- a/ap/activity.go +++ b/ap/activity.go @@ -56,6 +56,10 @@ type Activity struct { var ErrInvalidActivity = errors.New("Invalid activity") +func (a *Activity) IsPublic() bool { + return a.To.Contains(Public) || a.CC.Contains(Public) +} + func (a *Activity) UnmarshalJSON(b []byte) error { var common anyActivity if err := json.Unmarshal(b, &common); err != nil { diff --git a/fed/deliver.go b/fed/deliver.go index ac453211..50f24a2a 100644 --- a/fed/deliver.go +++ b/fed/deliver.go @@ -58,7 +58,7 @@ func ProcessQueue(ctx context.Context, log *slog.Logger, db *sql.DB, resolver *R func processQueue(ctx context.Context, log *slog.Logger, db *sql.DB, resolver *Resolver) error { log.Debug("Polling delivery queue") - rows, err := db.QueryContext(ctx, `select outbox.attempts, outbox.activity, persons.actor from outbox join persons on persons.id = outbox.activity->>'actor' where outbox.sent = 0 and (outbox.attempts = 0 or (outbox.attempts < ? and outbox.last <= unixepoch() - ?)) order by outbox.attempts asc, outbox.last asc limit ?`, MaxDeliveryAttempts, deliveryRetryInterval, batchSize) + rows, err := db.QueryContext(ctx, `select outbox.attempts, outbox.activity, persons.actor from outbox join persons on persons.id = outbox.sender where outbox.sent = 0 and (outbox.attempts = 0 or (outbox.attempts < ? and outbox.last <= unixepoch() - ?)) order by outbox.attempts asc, outbox.last asc limit ?`, MaxDeliveryAttempts, deliveryRetryInterval, batchSize) if err != nil { return fmt.Errorf("Failed to fetch posts to deliver: %w", err) } @@ -89,7 +89,7 @@ func processQueue(ctx context.Context, log *slog.Logger, db *sql.DB, resolver *R continue } - if err := deliverWithTimeout(ctx, log, db, resolver, &activity, &actor); err != nil { + if err := deliverWithTimeout(ctx, log, db, resolver, &activity, []byte(activityString), &actor); err != nil { log.Warn("Failed to deliver activity", "id", activity.ID, "attempts", deliveryAttempts, "error", err) continue } @@ -105,43 +105,42 @@ func processQueue(ctx context.Context, log *slog.Logger, db *sql.DB, resolver *R return nil } -func deliverWithTimeout(parent context.Context, log *slog.Logger, db *sql.DB, resolver *Resolver, activity *ap.Activity, actor *ap.Actor) error { +func deliverWithTimeout(parent context.Context, log *slog.Logger, db *sql.DB, resolver *Resolver, activity *ap.Activity, rawActivity []byte, actor *ap.Actor) error { ctx, cancel := context.WithTimeout(parent, deliveryTimeout) defer cancel() - return deliver(ctx, log, db, activity, actor, resolver) + return deliver(ctx, log, db, activity, rawActivity, actor, resolver) } -func deliver(ctx context.Context, log *slog.Logger, db *sql.DB, activity *ap.Activity, actor *ap.Actor, resolver *Resolver) error { - buf, err := json.Marshal(activity) - if err != nil { - return fmt.Errorf("Failed to marshal activity: %w", err) - } +func deliver(ctx context.Context, log *slog.Logger, db *sql.DB, activity *ap.Activity, rawActivity []byte, actor *ap.Actor, resolver *Resolver) error { + isForwarded := activity.Actor != actor.ID // deduplicate recipients recipients := data.OrderedMap[string, struct{}]{} - activity.To.Range(func(id string, _ struct{}) bool { - recipients.Store(id, struct{}{}) - return true - }) + if !isForwarded { + activity.To.Range(func(id string, _ struct{}) bool { + recipients.Store(id, struct{}{}) + return true + }) - activity.CC.Range(func(id string, _ struct{}) bool { - recipients.Store(id, struct{}{}) - return true - }) + activity.CC.Range(func(id string, _ struct{}) bool { + recipients.Store(id, struct{}{}) + return true + }) + } actorIDs := data.OrderedMap[string, struct{}]{} - // list the author's federated followers - if obj, ok := activity.Object.(*ap.Object); ok && obj.Type == ap.NoteObject && (obj.IsPublic() || recipients.Contains(actor.Followers)) { + // list the actor's federated followers if we're forwarding an activity by another actor, or if addressed by actor + if isForwarded || (activity.Actor == actor.ID && (activity.IsPublic() || recipients.Contains(actor.Followers))) { followers, err := db.QueryContext(ctx, `select distinct follower from follows where followed = ? and follower not like ? and accepted = 1`, actor.ID, fmt.Sprintf("https://%s/%%", cfg.Domain)) if err != nil { - log.Warn("Failed to list followers", "post", obj.ID, "error", err) + log.Warn("Failed to list followers", "activity", activity.ID, "error", err) } else { for followers.Next() { var follower string if err := followers.Scan(&follower); err != nil { - log.Warn("Skipped a follower", "post", obj.ID, "error", err) + log.Warn("Skipped a follower", "activity", activity.ID, "error", err) continue } @@ -153,18 +152,26 @@ func deliver(ctx context.Context, log *slog.Logger, db *sql.DB, activity *ap.Act } // assume that all other federated recipients are actors and not collections - prefix := fmt.Sprintf("https://%s/", cfg.Domain) recipients.Range(func(recipient string, _ struct{}) bool { - if recipient != ap.Public && !strings.HasPrefix(recipient, prefix) { - actorIDs.Store(recipient, struct{}{}) - } - + actorIDs.Store(recipient, struct{}{}) return true }) anyFailed := false + var author string + if obj, ok := activity.Object.(*ap.Object); ok { + author = obj.AttributedTo + } + + prefix := fmt.Sprintf("https://%s/", cfg.Domain) + actorIDs.Range(func(actorID string, _ struct{}) bool { + if actorID == author || actorID == ap.Public || strings.HasPrefix(actorID, prefix) { + log.Debug("Skipping recipient", "to", actorID, "activity", activity.ID) + return true + } + log.Info("Delivering activity to recipient", "to", actorID, "activity", activity.ID) if to, err := resolver.Resolve(ctx, log, db, actor, actorID, false); err != nil { @@ -172,8 +179,8 @@ func deliver(ctx context.Context, log *slog.Logger, db *sql.DB, activity *ap.Act if !errors.Is(err, ErrActorGone) && !errors.Is(err, ErrBlockedDomain) { anyFailed = true } - } else if err := Send(ctx, log, db, actor, resolver, to, buf); err != nil { - log.Warn("Failed to send a post", "to", actorID, "activity", activity.ID, "error", err) + } else if err := Send(ctx, log, db, actor, resolver, to, rawActivity); err != nil { + log.Warn("Failed to send an activity", "to", actorID, "activity", activity.ID, "error", err) if !errors.Is(err, ErrBlockedDomain) { anyFailed = true } diff --git a/inbox/queue.go b/inbox/queue.go index f3474da9..afa8ed98 100644 --- a/inbox/queue.go +++ b/inbox/queue.go @@ -39,9 +39,58 @@ const ( activitiesPollingInterval = time.Second * 5 activitiesBatchDelay = time.Millisecond * 100 activityProcessingTimeout = time.Second * 15 + maxForwardingDepth = 5 ) -func processCreateActivity(ctx context.Context, log *slog.Logger, sender *ap.Actor, req *ap.Activity, post *ap.Object, db *sql.DB, resolver *fed.Resolver, from *ap.Actor) error { +// a reply by B in a thread started by A is forwarded to all followers of A +func forwardActivity(ctx context.Context, log *slog.Logger, tx *sql.Tx, activity *ap.Activity, rawActivity []byte) error { + obj, ok := activity.Object.(*ap.Object) + if !ok { + return nil + } + + var firstPostID, threadStarterID string + var depth int + if err := tx.QueryRowContext(ctx, `with recursive thread(id, author, parent, depth) as (select notes.id, notes.author, notes.object->>'inReplyTo' as parent, 1 as depth from notes where id = ? union select notes.id, notes.author, notes.object->>'inReplyTo' as parent, t.depth + 1 from thread t join notes on notes.id = t.parent) select id, author, depth from thread order by depth desc limit 1`, obj.ID).Scan(&firstPostID, &threadStarterID, &depth); err != nil && errors.Is(err, sql.ErrNoRows) { + log.Debug("Failed to find thread for post", "post", obj.ID) + return nil + } else if err != nil { + return fmt.Errorf("failed to fetch first post in thread: %w", err) + } + if depth > maxForwardingDepth { + log.Debug("Thread exceeds depth limit for forwarding") + return nil + } + + prefix := fmt.Sprintf("https://%s/", cfg.Domain) + if !strings.HasPrefix(threadStarterID, prefix) { + log.Debug("Thread starter is federated") + return nil + } + + var shouldForward int + if err := tx.QueryRowContext(ctx, `select exists (select 1 from notes join persons on persons.id = notes.author and (notes.public = 1 or exists (select 1 from json_each(notes.object->'to') where value = persons.actor->>'followers') or exists (select 1 from json_each(notes.object->'cc') where value = persons.actor->>'followers')) where notes.id = ?)`, firstPostID).Scan(&shouldForward); err != nil { + return err + } + if shouldForward == 0 { + log.Debug("Activity does not need to be forwarded") + return nil + } + + if _, err := tx.ExecContext( + ctx, + `INSERT INTO outbox (activity, sender) VALUES(?,?)`, + string(rawActivity), + threadStarterID, + ); err != nil { + return err + } + + log.Info("Forwarding activity to followers of thread starter", "thread", firstPostID, "starter", threadStarterID) + return nil +} + +func processCreateActivity(ctx context.Context, log *slog.Logger, sender *ap.Actor, req *ap.Activity, rawActivity []byte, post *ap.Object, db *sql.DB, resolver *fed.Resolver, from *ap.Actor) error { prefix := fmt.Sprintf("https://%s/", cfg.Domain) if strings.HasPrefix(sender.ID, prefix) || strings.HasPrefix(post.ID, prefix) || strings.HasPrefix(post.AttributedTo, prefix) || strings.HasPrefix(req.Actor, prefix) { return fmt.Errorf("Received invalid Create for %s by %s from %s", post.ID, post.AttributedTo, req.Actor) @@ -68,6 +117,11 @@ func processCreateActivity(ctx context.Context, log *slog.Logger, sender *ap.Act if err := note.Insert(ctx, log, tx, post); err != nil { return fmt.Errorf("Cannot insert %s: %w", post.ID, err) } + + if err := forwardActivity(ctx, log, tx, req, rawActivity); err != nil { + return fmt.Errorf("Cannot forward %s: %w", post.ID, err) + } + if err := tx.Commit(); err != nil { return fmt.Errorf("Cannot insert %s: %w", post.ID, err) } @@ -93,7 +147,7 @@ func processCreateActivity(ctx context.Context, log *slog.Logger, sender *ap.Act return nil } -func processActivity(ctx context.Context, log *slog.Logger, sender *ap.Actor, req *ap.Activity, db *sql.DB, resolver *fed.Resolver, from *ap.Actor) error { +func processActivity(ctx context.Context, log *slog.Logger, sender *ap.Actor, req *ap.Activity, rawActivity []byte, db *sql.DB, resolver *fed.Resolver, from *ap.Actor) error { log.Debug("Processing activity") switch req.Type { @@ -205,7 +259,7 @@ func processActivity(ctx context.Context, log *slog.Logger, sender *ap.Actor, re return errors.New("Received invalid Create") } - return processCreateActivity(ctx, log, sender, req, post, db, resolver, from) + return processCreateActivity(ctx, log, sender, req, rawActivity, post, db, resolver, from) case ap.AnnounceActivity: create, ok := req.Object.(*ap.Activity) @@ -228,7 +282,7 @@ func processActivity(ctx context.Context, log *slog.Logger, sender *ap.Actor, re return errors.New("Sender is not post author or recipient") } - return processCreateActivity(ctx, log, sender, create, post, db, resolver, from) + return processCreateActivity(ctx, log, sender, create, rawActivity, post, db, resolver, from) case ap.UpdateActivity: post, ok := req.Object.(*ap.Object) @@ -236,14 +290,15 @@ func processActivity(ctx context.Context, log *slog.Logger, sender *ap.Actor, re return errors.New("Received invalid Update") } - if sender.ID != post.AttributedTo { + prefix := fmt.Sprintf("https://%s/", cfg.Domain) + if strings.HasPrefix(post.ID, prefix) { return fmt.Errorf("%s cannot update posts by %s", sender.ID, post.AttributedTo) } var lastUpdate sql.NullInt64 if err := db.QueryRowContext(ctx, `select max(inserted, updated) from notes where id = ? and author = ?`, post.ID, post.AttributedTo).Scan(&lastUpdate); err != nil && errors.Is(err, sql.ErrNoRows) { log.Debug("Received Update for non-existing post") - return processCreateActivity(ctx, log, sender, req, post, db, resolver, from) + return processCreateActivity(ctx, log, sender, req, rawActivity, post, db, resolver, from) } else if err != nil { return fmt.Errorf("Failed to get last update time for %s: %w", post.ID, err) } @@ -258,7 +313,13 @@ func processActivity(ctx context.Context, log *slog.Logger, sender *ap.Actor, re return fmt.Errorf("Failed to update post %s: %w", post.ID, err) } - if _, err := db.ExecContext( + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("Cannot insert %s: %w", post.ID, err) + } + defer tx.Rollback() + + if _, err := tx.ExecContext( ctx, `update notes set object = ?, updated = unixepoch() where id = ?`, string(body), @@ -267,6 +328,14 @@ func processActivity(ctx context.Context, log *slog.Logger, sender *ap.Actor, re return fmt.Errorf("Failed to update post %s: %w", post.ID, err) } + if err := forwardActivity(ctx, log, tx, req, rawActivity); err != nil { + return fmt.Errorf("Failed to forward update pos %s: %w", post.ID, err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("Failed to update post %s: %w", post.ID, err) + } + log.Info("Updated post") case ap.LikeActivity: @@ -283,24 +352,24 @@ func processActivity(ctx context.Context, log *slog.Logger, sender *ap.Actor, re return nil } -func processActivityWithTimeout(parent context.Context, log *slog.Logger, sender *ap.Actor, activity *ap.Activity, db *sql.DB, resolver *fed.Resolver, from *ap.Actor) { +func processActivityWithTimeout(parent context.Context, log *slog.Logger, sender *ap.Actor, activity *ap.Activity, rawActivity []byte, db *sql.DB, resolver *fed.Resolver, from *ap.Actor) { ctx, cancel := context.WithTimeout(parent, activityProcessingTimeout) defer cancel() if o, ok := activity.Object.(*ap.Object); ok { - log = log.With(slog.Group("activity", "sender", sender.ID, "type", activity.Type, "actor", activity.Actor, slog.Group("object", "kind", "object", "id", o.ID, "type", o.Type, "attributed_to", o.AttributedTo))) + log = log.With(slog.Group("activity", "id", activity.ID, "sender", sender.ID, "type", activity.Type, "actor", activity.Actor, slog.Group("object", "kind", "object", "id", o.ID, "type", o.Type, "attributed_to", o.AttributedTo))) } else if a, ok := activity.Object.(*ap.Activity); ok { - log = log.With(slog.Group("activity", "sender", sender.ID, "type", activity.Type, "actor", activity.Actor, slog.Group("object", "kind", "activity", "id", a.ID, "type", a.Type, "actor", a.Actor))) + log = log.With(slog.Group("activity", "id", activity.ID, "sender", sender.ID, "type", activity.Type, "actor", activity.Actor, slog.Group("object", "kind", "activity", "id", a.ID, "type", a.Type, "actor", a.Actor))) } else if s, ok := activity.Object.(string); ok { - log = log.With(slog.Group("activity", "sender", sender.ID, "type", activity.Type, "actor", activity.Actor, slog.Group("object", "kind", "string", "id", s))) + log = log.With(slog.Group("activity", "id", activity.ID, "sender", sender.ID, "type", activity.Type, "actor", activity.Actor, slog.Group("object", "kind", "string", "id", s))) } - if err := processActivity(ctx, log, sender, activity, db, resolver, from); err != nil { + if err := processActivity(ctx, log, sender, activity, rawActivity, db, resolver, from); err != nil { log.Warn("Failed to process activity", "error", err) } } -func processBatch(ctx context.Context, log *slog.Logger, db *sql.DB, resolver *fed.Resolver, from *ap.Actor) (int, error) { +func ProcessBatch(ctx context.Context, log *slog.Logger, db *sql.DB, resolver *fed.Resolver, from *ap.Actor) (int, error) { log.Debug("Polling activities queue") rows, err := db.QueryContext(ctx, `select inbox.id, persons.actor, inbox.activity from (select * from inbox limit -1 offset case when (select count(*) from inbox) >= $1 then $1/10 else 0 end) inbox left join persons on persons.id = inbox.sender order by inbox.id limit $2`, maxActivitiesQueueSize, activitiesBatchSize) @@ -352,7 +421,7 @@ func processBatch(ctx context.Context, log *slog.Logger, db *sql.DB, resolver *f return true } - processActivityWithTimeout(ctx, log, &sender, &activity, db, resolver, from) + processActivityWithTimeout(ctx, log, &sender, &activity, []byte(activityString), db, resolver, from) return true }) @@ -368,7 +437,7 @@ func processQueue(ctx context.Context, log *slog.Logger, db *sql.DB, resolver *f defer t.Stop() for { - n, err := processBatch(ctx, log, db, resolver, from) + n, err := ProcessBatch(ctx, log, db, resolver, from) if err != nil { return err } diff --git a/migrations/007_outboxsender.go b/migrations/007_outboxsender.go new file mode 100644 index 00000000..195c7f25 --- /dev/null +++ b/migrations/007_outboxsender.go @@ -0,0 +1,18 @@ +package migrations + +import ( + "context" + "database/sql" +) + +func outboxsender(ctx context.Context, tx *sql.Tx) error { + if _, err := tx.ExecContext(ctx, `ALTER TABLE outbox ADD COLUMN sender STRING`); err != nil { + return err + } + + if _, err := tx.ExecContext(ctx, `UPDATE outbox SET sender = activity->>'actor'`); err != nil { + return err + } + + return nil +} diff --git a/outbox/accept.go b/outbox/accept.go index 2305702f..8ceaf39a 100644 --- a/outbox/accept.go +++ b/outbox/accept.go @@ -54,8 +54,9 @@ func Accept(ctx context.Context, followed, follower, followID string, db *sql.DB if _, err := tx.ExecContext( ctx, - `INSERT INTO outbox (activity) VALUES(?)`, + `INSERT INTO outbox (activity, sender) VALUES(?,?)`, string(accept), + followed, ); err != nil { return fmt.Errorf("Failed to insert Accept: %w", err) } diff --git a/outbox/create.go b/outbox/create.go index 5a0a8102..e7383266 100644 --- a/outbox/create.go +++ b/outbox/create.go @@ -68,7 +68,7 @@ func Create(ctx context.Context, log *slog.Logger, db *sql.DB, post *ap.Object, return fmt.Errorf("Failed to insert note: %w", err) } - if _, err = tx.ExecContext(ctx, `insert into outbox (activity) values(?)`, string(create)); err != nil { + if _, err = tx.ExecContext(ctx, `insert into outbox (activity, sender) values(?,?)`, string(create), author.ID); err != nil { return fmt.Errorf("Failed to insert Create: %w", err) } diff --git a/outbox/delete.go b/outbox/delete.go index e0db5425..9207c5cb 100644 --- a/outbox/delete.go +++ b/outbox/delete.go @@ -74,8 +74,9 @@ func Delete(ctx context.Context, db *sql.DB, note *ap.Object) error { if _, err := tx.ExecContext( ctx, - `INSERT INTO outbox (activity) VALUES (?)`, + `INSERT INTO outbox (activity, sender) VALUES (?,?)`, string(delete), + note.AttributedTo, ); err != nil { return fmt.Errorf("Failed to insert delete activity: %w", err) } diff --git a/outbox/edit.go b/outbox/edit.go index c6465fd0..1c3af567 100644 --- a/outbox/edit.go +++ b/outbox/edit.go @@ -70,8 +70,9 @@ func Edit(ctx context.Context, db *sql.DB, note *ap.Object, newContent string) e if _, err := tx.ExecContext( ctx, - `INSERT INTO outbox (activity) VALUES(?)`, + `INSERT INTO outbox (activity, sender) VALUES(?,?)`, string(update), + note.AttributedTo, ); err != nil { return fmt.Errorf("Failed to insert update activity: %w", err) } diff --git a/outbox/follow.go b/outbox/follow.go index 2b5e6b9f..ba95c9c2 100644 --- a/outbox/follow.go +++ b/outbox/follow.go @@ -70,8 +70,9 @@ func Follow(ctx context.Context, follower *ap.Actor, followed string, db *sql.DB if _, err := tx.ExecContext( ctx, - `INSERT INTO outbox (activity) VALUES(?)`, + `INSERT INTO outbox (activity, sender) VALUES(?,?)`, string(body), + follower.ID, ); err != nil { return fmt.Errorf("Failed to insert follow activity: %w", err) } diff --git a/outbox/unfollow.go b/outbox/unfollow.go index e3ffbc60..83779f38 100644 --- a/outbox/unfollow.go +++ b/outbox/unfollow.go @@ -72,8 +72,9 @@ func Unfollow(ctx context.Context, log *slog.Logger, db *sql.DB, follower *ap.Ac if _, err := tx.ExecContext( ctx, - `INSERT INTO outbox (activity) VALUES(?)`, + `INSERT INTO outbox (activity, sender) VALUES(?,?)`, string(body), + follower.ID, ); err != nil { return fmt.Errorf("Failed to insert undo for %s: %w", followID, err) } diff --git a/test/forward_test.go b/test/forward_test.go new file mode 100644 index 00000000..a572f686 --- /dev/null +++ b/test/forward_test.go @@ -0,0 +1,557 @@ +/* +Copyright 2023 Dima Krasner + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package test + +import ( + "context" + "github.com/dimkr/tootik/ap" + "github.com/dimkr/tootik/fed" + "github.com/dimkr/tootik/inbox" + "github.com/dimkr/tootik/inbox/note" + "github.com/dimkr/tootik/outbox" + "github.com/stretchr/testify/assert" + "log/slog" + "testing" +) + +func TestForward_ReplyToPostByFollower(t *testing.T) { + server := newTestServer() + defer server.Shutdown() + + assert := assert.New(t) + + assert.NoError( + outbox.Accept( + context.Background(), + server.Alice.ID, + "https://127.0.0.1/user/dan", + "https://localhost.localdomain/follow/1", + server.db, + ), + ) + + to := ap.Audience{} + to.Add(server.Alice.Followers) + + tx, err := server.db.BeginTx(context.Background(), nil) + assert.NoError(err) + defer tx.Rollback() + + assert.NoError( + note.Insert( + context.Background(), + slog.Default(), + tx, + &ap.Object{ + ID: "https://localhost.localdomain/note/1", + Type: ap.NoteObject, + AttributedTo: server.Alice.ID, + Content: "hello", + To: to, + }, + ), + ) + + assert.NoError(tx.Commit()) + + _, err = server.db.Exec( + `insert into persons (id, hash, actor) values(?,?,?)`, + "https://127.0.0.1/user/dan", + "eab50d465047c1ccfc581759f33612c583486044f5de62b2a5e77e220c2f1ae3", + `{"type":"Person"}`, + ) + assert.NoError(err) + + reply := `{"@context":["https://www.w3.org/ns/activitystreams"],"id":"https://127.0.0.1/create/1","type":"Create","actor":"https://127.0.0.1/user/dan","object":{"id":"https://127.0.0.1/note/1","type":"Note","attributedTo":"https://127.0.0.1/user/dan","inReplyTo":"https://localhost.localdomain/note/1","content":"bye","to":["https://localhost.localdomain/user/alice"],"cc":["https://localhost.localdomain/followers/alice"]},"to":["https://localhost.localdomain/user/alice"],"cc":["https://localhost.localdomain/followers/alice"]}` + + _, err = server.db.Exec( + `insert into inbox (sender, activity) values(?,?)`, + "https://127.0.0.1/user/dan", + reply, + ) + assert.NoError(err) + + n, err := inbox.ProcessBatch(context.Background(), slog.Default(), server.db, fed.NewResolver(nil), server.Nobody) + assert.NoError(err) + assert.Equal(1, n) + + var forwarded int + assert.NoError(server.db.QueryRow(`select exists (select 1 from outbox where activity = ? and sender = ?)`, reply, server.Alice.ID).Scan(&forwarded)) + assert.Equal(1, forwarded) +} + +func TestForward_ReplyToPublicPost(t *testing.T) { + server := newTestServer() + defer server.Shutdown() + + assert := assert.New(t) + + assert.NoError( + outbox.Accept( + context.Background(), + server.Alice.ID, + "https://127.0.0.1/user/dan", + "https://localhost.localdomain/follow/1", + server.db, + ), + ) + + to := ap.Audience{} + to.Add(ap.Public) + + cc := ap.Audience{} + cc.Add(server.Alice.Followers) + + tx, err := server.db.BeginTx(context.Background(), nil) + assert.NoError(err) + defer tx.Rollback() + + assert.NoError( + note.Insert( + context.Background(), + slog.Default(), + tx, + &ap.Object{ + ID: "https://localhost.localdomain/note/1", + Type: ap.NoteObject, + AttributedTo: server.Alice.ID, + Content: "hello", + To: to, + CC: cc, + }, + ), + ) + + assert.NoError(tx.Commit()) + + _, err = server.db.Exec( + `insert into persons (id, hash, actor) values(?,?,?)`, + "https://127.0.0.1/user/dan", + "eab50d465047c1ccfc581759f33612c583486044f5de62b2a5e77e220c2f1ae3", + `{"type":"Person"}`, + ) + assert.NoError(err) + + reply := `{"@context":["https://www.w3.org/ns/activitystreams"],"id":"https://127.0.0.1/create/1","type":"Create","actor":"https://127.0.0.1/user/dan","object":{"id":"https://127.0.0.1/note/1","type":"Note","attributedTo":"https://127.0.0.1/user/dan","inReplyTo":"https://localhost.localdomain/note/1","content":"bye","to":["https://localhost.localdomain/user/alice"],"cc":["https://localhost.localdomain/followers/alice"]},"to":["https://localhost.localdomain/user/alice"],"cc":["https://localhost.localdomain/followers/alice"]}` + + _, err = server.db.Exec( + `insert into inbox (sender, activity) values(?,?)`, + "https://127.0.0.1/user/dan", + reply, + ) + assert.NoError(err) + + n, err := inbox.ProcessBatch(context.Background(), slog.Default(), server.db, fed.NewResolver(nil), server.Nobody) + assert.NoError(err) + assert.Equal(1, n) + + var forwarded int + assert.NoError(server.db.QueryRow(`select exists (select 1 from outbox where activity = ? and sender = ?)`, reply, server.Alice.ID).Scan(&forwarded)) + assert.Equal(1, forwarded) +} + +func TestForward_ReplyToReplyToPostByFollower(t *testing.T) { + server := newTestServer() + defer server.Shutdown() + + assert := assert.New(t) + + assert.NoError( + outbox.Accept( + context.Background(), + server.Alice.ID, + "https://127.0.0.1/user/dan", + "https://localhost.localdomain/follow/1", + server.db, + ), + ) + + to := ap.Audience{} + to.Add(server.Alice.Followers) + + tx, err := server.db.BeginTx(context.Background(), nil) + assert.NoError(err) + defer tx.Rollback() + + assert.NoError( + note.Insert( + context.Background(), + slog.Default(), + tx, + &ap.Object{ + ID: "https://localhost.localdomain/note/1", + Type: ap.NoteObject, + AttributedTo: server.Alice.ID, + Content: "hello", + To: to, + }, + ), + ) + + assert.NoError( + note.Insert( + context.Background(), + slog.Default(), + tx, + &ap.Object{ + ID: "https://localhost.localdomain/note/2", + Type: ap.NoteObject, + AttributedTo: server.Bob.ID, + InReplyTo: "https://localhost.localdomain/note/1", + Content: "hola", + To: to, + }, + ), + ) + + assert.NoError(tx.Commit()) + + _, err = server.db.Exec( + `insert into persons (id, hash, actor) values(?,?,?)`, + "https://127.0.0.1/user/dan", + "eab50d465047c1ccfc581759f33612c583486044f5de62b2a5e77e220c2f1ae3", + `{"type":"Person"}`, + ) + assert.NoError(err) + + reply := `{"@context":["https://www.w3.org/ns/activitystreams"],"id":"https://127.0.0.1/create/1","type":"Create","actor":"https://127.0.0.1/user/dan","object":{"id":"https://127.0.0.1/note/1","type":"Note","attributedTo":"https://127.0.0.1/user/dan","inReplyTo":"https://localhost.localdomain/note/2","content":"bye","to":["https://localhost.localdomain/user/alice"],"cc":["https://localhost.localdomain/followers/bob"]},"to":["https://localhost.localdomain/user/alice"],"cc":["https://localhost.localdomain/followers/bob"]}` + + _, err = server.db.Exec( + `insert into inbox (sender, activity) values(?,?)`, + "https://127.0.0.1/user/dan", + reply, + ) + assert.NoError(err) + + n, err := inbox.ProcessBatch(context.Background(), slog.Default(), server.db, fed.NewResolver(nil), server.Nobody) + assert.NoError(err) + assert.Equal(1, n) + + var forwarded int + assert.NoError(server.db.QueryRow(`select exists (select 1 from outbox where activity = ? and sender = ?)`, reply, server.Alice.ID).Scan(&forwarded)) + assert.Equal(1, forwarded) +} + +func TestForward_ReplyToUnknownPost(t *testing.T) { + server := newTestServer() + defer server.Shutdown() + + assert := assert.New(t) + + assert.NoError( + outbox.Accept( + context.Background(), + server.Alice.ID, + "https://127.0.0.1/user/dan", + "https://localhost.localdomain/follow/1", + server.db, + ), + ) + + to := ap.Audience{} + to.Add(server.Alice.Followers) + + tx, err := server.db.BeginTx(context.Background(), nil) + assert.NoError(err) + defer tx.Rollback() + + assert.NoError( + note.Insert( + context.Background(), + slog.Default(), + tx, + &ap.Object{ + ID: "https://localhost.localdomain/note/1", + Type: ap.NoteObject, + AttributedTo: server.Alice.ID, + Content: "hello", + To: to, + }, + ), + ) + + assert.NoError(tx.Commit()) + + _, err = server.db.Exec( + `insert into persons (id, hash, actor) values(?,?,?)`, + "https://127.0.0.1/user/dan", + "eab50d465047c1ccfc581759f33612c583486044f5de62b2a5e77e220c2f1ae3", + `{"type":"Person"}`, + ) + assert.NoError(err) + + reply := `{"@context":["https://www.w3.org/ns/activitystreams"],"id":"https://127.0.0.1/create/1","type":"Create","actor":"https://127.0.0.1/user/dan","object":{"id":"https://127.0.0.1/note/1","type":"Note","attributedTo":"https://127.0.0.1/user/dan","inReplyTo":"https://localhost.localdomain/note/3","content":"bye","to":["https://localhost.localdomain/user/alice"],"cc":["https://localhost.localdomain/followers/alice"]},"to":["https://localhost.localdomain/user/alice"],"cc":["https://localhost.localdomain/followers/alice"]}` + + _, err = server.db.Exec( + `insert into inbox (sender, activity) values(?,?)`, + "https://127.0.0.1/user/dan", + reply, + ) + assert.NoError(err) + + n, err := inbox.ProcessBatch(context.Background(), slog.Default(), server.db, fed.NewResolver(nil), server.Nobody) + assert.NoError(err) + assert.Equal(1, n) + + var forwarded int + assert.NoError(server.db.QueryRow(`select exists (select 1 from outbox where activity = ?)`, reply).Scan(&forwarded)) + assert.Equal(0, forwarded) +} + +func TestForward_ReplyToDM(t *testing.T) { + server := newTestServer() + defer server.Shutdown() + + assert := assert.New(t) + + assert.NoError( + outbox.Accept( + context.Background(), + server.Alice.ID, + "https://127.0.0.1/user/dan", + "https://localhost.localdomain/follow/1", + server.db, + ), + ) + + to := ap.Audience{} + to.Add(server.Bob.ID) + + tx, err := server.db.BeginTx(context.Background(), nil) + assert.NoError(err) + defer tx.Rollback() + + assert.NoError( + note.Insert( + context.Background(), + slog.Default(), + tx, + &ap.Object{ + ID: "https://localhost.localdomain/note/1", + Type: ap.NoteObject, + AttributedTo: server.Alice.ID, + Content: "hello", + To: to, + }, + ), + ) + + assert.NoError(tx.Commit()) + + _, err = server.db.Exec( + `insert into persons (id, hash, actor) values(?,?,?)`, + "https://127.0.0.1/user/dan", + "eab50d465047c1ccfc581759f33612c583486044f5de62b2a5e77e220c2f1ae3", + `{"type":"Person"}`, + ) + assert.NoError(err) + + reply := `{"@context":["https://www.w3.org/ns/activitystreams"],"id":"https://127.0.0.1/create/1","type":"Create","actor":"https://127.0.0.1/user/dan","object":{"id":"https://127.0.0.1/note/1","type":"Note","attributedTo":"https://127.0.0.1/user/dan","inReplyTo":"https://localhost.localdomain/note/1","content":"bye","to":["https://localhost.localdomain/user/alice"],"cc":["https://localhost.localdomain/followers/alice"]},"to":["https://localhost.localdomain/user/alice"],"cc":["https://localhost.localdomain/followers/alice"]}` + + _, err = server.db.Exec( + `insert into inbox (sender, activity) values(?,?)`, + "https://127.0.0.1/user/dan", + reply, + ) + assert.NoError(err) + + n, err := inbox.ProcessBatch(context.Background(), slog.Default(), server.db, fed.NewResolver(nil), server.Nobody) + assert.NoError(err) + assert.Equal(1, n) + + var forwarded int + assert.NoError(server.db.QueryRow(`select exists (select 1 from outbox where activity = ?)`, reply).Scan(&forwarded)) + assert.Equal(0, forwarded) +} + +func TestForward_NotFollowingAuthor(t *testing.T) { + server := newTestServer() + defer server.Shutdown() + + assert := assert.New(t) + + to := ap.Audience{} + to.Add(server.Alice.Followers) + + tx, err := server.db.BeginTx(context.Background(), nil) + assert.NoError(err) + defer tx.Rollback() + + assert.NoError( + note.Insert( + context.Background(), + slog.Default(), + tx, + &ap.Object{ + ID: "https://localhost.localdomain/note/1", + Type: ap.NoteObject, + AttributedTo: server.Alice.ID, + Content: "hello", + To: to, + }, + ), + ) + + assert.NoError(tx.Commit()) + + _, err = server.db.Exec( + `insert into persons (id, hash, actor) values(?,?,?)`, + "https://127.0.0.1/user/dan", + "eab50d465047c1ccfc581759f33612c583486044f5de62b2a5e77e220c2f1ae3", + `{"type":"Person"}`, + ) + assert.NoError(err) + + reply := `{"@context":["https://www.w3.org/ns/activitystreams"],"id":"https://127.0.0.1/create/1","type":"Create","actor":"https://127.0.0.1/user/dan","object":{"id":"https://127.0.0.1/note/1","type":"Note","attributedTo":"https://127.0.0.1/user/dan","inReplyTo":"https://localhost.localdomain/note/1","content":"bye","to":["https://localhost.localdomain/user/alice"],"cc":["https://localhost.localdomain/followers/alice"]},"to":["https://localhost.localdomain/user/alice"],"cc":["https://localhost.localdomain/followers/alice"]}` + + _, err = server.db.Exec( + `insert into inbox (sender, activity) values(?,?)`, + "https://127.0.0.1/user/dan", + reply, + ) + assert.NoError(err) + + n, err := inbox.ProcessBatch(context.Background(), slog.Default(), server.db, fed.NewResolver(nil), server.Nobody) + assert.NoError(err) + assert.Equal(1, n) + + var forwarded int + assert.NoError(server.db.QueryRow(`select exists (select 1 from outbox where activity = ? and sender = ?)`, reply, server.Alice.ID).Scan(&forwarded)) + assert.Equal(1, forwarded) +} + +func TestForward_NotReplyToLocalPost(t *testing.T) { + server := newTestServer() + defer server.Shutdown() + + assert := assert.New(t) + + assert.NoError( + outbox.Accept( + context.Background(), + server.Alice.ID, + "https://127.0.0.1/user/dan", + "https://localhost.localdomain/follow/1", + server.db, + ), + ) + + to := ap.Audience{} + to.Add(server.Alice.Followers) + + tx, err := server.db.BeginTx(context.Background(), nil) + assert.NoError(err) + defer tx.Rollback() + + assert.NoError( + note.Insert( + context.Background(), + slog.Default(), + tx, + &ap.Object{ + ID: "https://localhost.localdomain/note/1", + Type: ap.NoteObject, + AttributedTo: server.Alice.ID, + Content: "hello", + To: to, + }, + ), + ) + + assert.NoError(tx.Commit()) + + _, err = server.db.Exec( + `insert into persons (id, hash, actor) values(?,?,?)`, + "https://127.0.0.1/user/dan", + "eab50d465047c1ccfc581759f33612c583486044f5de62b2a5e77e220c2f1ae3", + `{"type":"Person"}`, + ) + assert.NoError(err) + + reply := `{"@context":["https://www.w3.org/ns/activitystreams"],"id":"https://127.0.0.1/create/1","type":"Create","actor":"https://127.0.0.1/user/dan","object":{"id":"https://127.0.0.1/note/1","type":"Note","attributedTo":"https://127.0.0.1/user/dan","inReplyTo":"https://127.0.0.1/note/2","content":"bye","to":["https://localhost.localdomain/user/alice"],"cc":["https://localhost.localdomain/followers/alice"]},"to":["https://localhost.localdomain/user/alice"],"cc":["https://localhost.localdomain/followers/alice"]}` + + _, err = server.db.Exec( + `insert into inbox (sender, activity) values(?,?)`, + "https://127.0.0.1/user/dan", + reply, + ) + assert.NoError(err) + + n, err := inbox.ProcessBatch(context.Background(), slog.Default(), server.db, fed.NewResolver(nil), server.Nobody) + assert.NoError(err) + assert.Equal(1, n) + + var forwarded int + assert.NoError(server.db.QueryRow(`select exists (select 1 from outbox where activity = ?)`, reply).Scan(&forwarded)) + assert.Equal(0, forwarded) +} + +func TestForward_ReplyToFederatedPost(t *testing.T) { + server := newTestServer() + defer server.Shutdown() + + assert := assert.New(t) + + to := ap.Audience{} + to.Add("https://127.0.0.1/followers/erin") + + tx, err := server.db.BeginTx(context.Background(), nil) + assert.NoError(err) + defer tx.Rollback() + + assert.NoError( + note.Insert( + context.Background(), + slog.Default(), + tx, + &ap.Object{ + ID: "https://127.0.0.1/note/1", + Type: ap.NoteObject, + AttributedTo: "https://127.0.0.1/user/erin", + Content: "hello", + To: to, + }, + ), + ) + + assert.NoError(tx.Commit()) + + _, err = server.db.Exec( + `insert into persons (id, hash, actor) values(?,?,?)`, + "https://127.0.0.1/user/dan", + "eab50d465047c1ccfc581759f33612c583486044f5de62b2a5e77e220c2f1ae3", + `{"type":"Person"}`, + ) + assert.NoError(err) + + reply := `{"@context":["https://www.w3.org/ns/activitystreams"],"id":"https://127.0.0.1/create/1","type":"Create","actor":"https://127.0.0.1/user/dan","object":{"id":"https://127.0.0.1/note/1","type":"Note","attributedTo":"https://127.0.0.1/user/dan","inReplyTo":"https://127.0.0.1/note/1","content":"bye","to":["https://127.0.0.1/user/erin"],"cc":["https://127.0.0.1/followers/erin"]},"to":["https://127.0.0.1/user/erin"],"cc":["https://127.0.0.1/followers/erin"]}` + + _, err = server.db.Exec( + `insert into inbox (sender, activity) values(?,?)`, + "https://127.0.0.1/user/dan", + reply, + ) + assert.NoError(err) + + n, err := inbox.ProcessBatch(context.Background(), slog.Default(), server.db, fed.NewResolver(nil), server.Nobody) + assert.NoError(err) + assert.Equal(1, n) + + var forwarded int + assert.NoError(server.db.QueryRow(`select exists (select 1 from outbox where activity = ?)`, reply).Scan(&forwarded)) + assert.Equal(0, forwarded) +}