diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 0000000..7d14cb6 --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,3 @@ +run: + build-tags: + - wireinject \ No newline at end of file diff --git a/cmd/notification-service/di/inject_adapters.go b/cmd/notification-service/di/inject_adapters.go index 74fdb47..d401b32 100644 --- a/cmd/notification-service/di/inject_adapters.go +++ b/cmd/notification-service/di/inject_adapters.go @@ -29,7 +29,3 @@ var firestoreTxAdaptersSet = wire.NewSet( firestore.NewEventRepository, wire.Bind(new(app.EventRepository), new(*firestore.EventRepository)), ) - -var adaptersSet = wire.NewSet( -// adapters.NewCurrentTimeProvider, -) diff --git a/cmd/notification-service/di/inject_badger.go b/cmd/notification-service/di/inject_badger.go deleted file mode 100644 index 5222295..0000000 --- a/cmd/notification-service/di/inject_badger.go +++ /dev/null @@ -1,134 +0,0 @@ -package di - -import ( - "github.com/google/wire" -) - -var badgerUnpackTestDependenciesSet = wire.NewSet( -// wire.FieldsOf(new(badgeradapters.TestAdaptersDependencies), -// -// "BanListHasher", -// "CurrentTimeProvider", -// "RawMessageIdentifier", -// "LocalIdentity", -// -// ), -// wire.Bind(new(badgeradapters.BanListHasher), new(*mocks2.BanListHasherMock)), -// wire.Bind(new(commands.CurrentTimeProvider), new(*mocks2.CurrentTimeProviderMock)), -// wire.Bind(new(badgeradapters.RawMessageIdentifier), new(*mocks2.RawMessageIdentifierMock)), -) - -var badgerAdaptersSet = wire.NewSet( -// badgeradapters.NewGarbageCollector, -) - -var badgerNoTxRepositoriesSet = wire.NewSet( -// notx.NewNoTxBlobWantListRepository, -// wire.Bind(new(blobReplication.WantedBlobsProvider), new(*notx.NoTxBlobWantListRepository)), -// wire.Bind(new(blobReplication.WantListRepository), new(*notx.NoTxBlobWantListRepository)), -// -// notx.NewNoTxBlobsRepository, -// wire.Bind(new(blobReplication.BlobsRepository), new(*notx.NoTxBlobsRepository)), -// -// notx.NewNoTxFeedWantListRepository, -) - -var badgerRepositoriesSet = wire.NewSet( -// badgeradapters.NewBanListRepository, -// wire.Bind(new(commands.BanListRepository), new(*badgeradapters.BanListRepository)), -// wire.Bind(new(queries.BanListRepository), new(*badgeradapters.BanListRepository)), -// -// badgeradapters.NewBlobWantListRepository, -// wire.Bind(new(commands.BlobWantListRepository), new(*badgeradapters.BlobWantListRepository)), -// wire.Bind(new(blobReplication.WantListRepository), new(*badgeradapters.BlobWantListRepository)), -// -// badgeradapters.NewFeedWantListRepository, -// wire.Bind(new(commands.FeedWantListRepository), new(*badgeradapters.FeedWantListRepository)), -// wire.Bind(new(queries.FeedWantListRepository), new(*badgeradapters.FeedWantListRepository)), -// -// badgeradapters.NewReceiveLogRepository, -// wire.Bind(new(commands.ReceiveLogRepository), new(*badgeradapters.ReceiveLogRepository)), -// wire.Bind(new(queries.ReceiveLogRepository), new(*badgeradapters.ReceiveLogRepository)), -// -// badgeradapters.NewSocialGraphRepository, -// wire.Bind(new(commands.SocialGraphRepository), new(*badgeradapters.SocialGraphRepository)), -// wire.Bind(new(queries.SocialGraphRepository), new(*badgeradapters.SocialGraphRepository)), -// -// badgeradapters.NewFeedRepository, -// wire.Bind(new(commands.FeedRepository), new(*badgeradapters.FeedRepository)), -// wire.Bind(new(queries.FeedRepository), new(*badgeradapters.FeedRepository)), -// -// badgeradapters.NewMessageRepository, -// wire.Bind(new(queries.MessageRepository), new(*badgeradapters.MessageRepository)), -// -// badgeradapters.NewPubRepository, -// badgeradapters.NewBlobRepository, -) - -//var badgerTestAdaptersDependenciesSet = wire.NewSet( -// wire.Struct(new(badgeradapters.TestAdaptersDependencies), "*"), -// mocks2.NewBanListHasherMock, -// mocks2.NewCurrentTimeProviderMock, -// mocks2.NewRawMessageIdentifierMock, -//) -// -//var badgerNoTxTestTransactionProviderSet = wire.NewSet( -// notx.NewTestTxAdaptersFactoryTransactionProvider, -// wire.Bind(new(notx.TransactionProvider), new(*notx.TestTxAdaptersFactoryTransactionProvider)), -// -// noTxTestTxAdaptersFactory, -//) -// -//var testBadgerTransactionProviderSet = wire.NewSet( -// badgeradapters.NewTestTransactionProvider, -// testAdaptersFactory, -//) -// -//var badgerTransactionProviderSet = wire.NewSet( -// badgeradapters.NewCommandsTransactionProvider, -// wire.Bind(new(commands.TransactionProvider), new(*badgeradapters.CommandsTransactionProvider)), -// -// badgerCommandsAdaptersFactory, -// -// badgeradapters.NewQueriesTransactionProvider, -// wire.Bind(new(queries.TransactionProvider), new(*badgeradapters.QueriesTransactionProvider)), -// -// badgerQueriesAdaptersFactory, -//) -// -//var badgerNoTxTransactionProviderSet = wire.NewSet( -// notx.NewTxAdaptersFactoryTransactionProvider, -// wire.Bind(new(notx.TransactionProvider), new(*notx.TxAdaptersFactoryTransactionProvider)), -// -// noTxTxAdaptersFactory, -//) -// -//func noTxTestTxAdaptersFactory() notx.TestTxAdaptersFactory { -// return func(tx *badger.Txn, dependencies badgeradapters.TestAdaptersDependencies) (notx.TxAdapters, error) { -// return buildTestBadgerNoTxTxAdapters(tx, dependencies) -// } -//} -// -//func noTxTxAdaptersFactory(local identity.Public, conf service.Config, logger logging.Logger) notx.TxAdaptersFactory { -// return func(tx *badger.Txn) (notx.TxAdapters, error) { -// return buildBadgerNoTxTxAdapters(tx, local, conf, logger) -// } -//} -// -//func testAdaptersFactory() badgeradapters.TestAdaptersFactory { -// return func(tx *badger.Txn, dependencies badgeradapters.TestAdaptersDependencies) (badgeradapters.TestAdapters, error) { -// return buildBadgerTestAdapters(tx, dependencies) -// } -//} -// -//func badgerCommandsAdaptersFactory(config service.Config, local identity.Public, logger logging.Logger) badgeradapters.CommandsAdaptersFactory { -// return func(tx *badger.Txn) (commands.Adapters, error) { -// return buildBadgerCommandsAdapters(tx, local, config, logger) -// } -//} -// -//func badgerQueriesAdaptersFactory(config service.Config, local identity.Public, logger logging.Logger) badgeradapters.QueriesAdaptersFactory { -// return func(tx *badger.Txn) (queries.Adapters, error) { -// return buildBadgerQueriesAdapters(tx, local, config, logger) -// } -//} diff --git a/cmd/notification-service/di/inject_downloader.go b/cmd/notification-service/di/inject_downloader.go new file mode 100644 index 0000000..bd72ea2 --- /dev/null +++ b/cmd/notification-service/di/inject_downloader.go @@ -0,0 +1,10 @@ +package di + +import ( + "github.com/google/wire" + "github.com/planetary-social/go-notification-service/service/app" +) + +var downloaderSet = wire.NewSet( + app.NewDownloader, +) diff --git a/cmd/notification-service/di/service.go b/cmd/notification-service/di/service.go index 957cfba..524c8d1 100644 --- a/cmd/notification-service/di/service.go +++ b/cmd/notification-service/di/service.go @@ -10,17 +10,20 @@ import ( ) type Service struct { - app app.Application - server http.Server + app app.Application + server http.Server + downloader *app.Downloader } func NewService( app app.Application, server http.Server, + downloader *app.Downloader, ) Service { return Service{ - app: app, - server: server, + app: app, + server: server, + downloader: downloader, } } @@ -40,6 +43,11 @@ func (s Service) Run(ctx context.Context) error { errCh <- s.server.ListenAndServe(ctx) }() + runners++ + go func() { + errCh <- s.downloader.Run(ctx) + }() + var err error for i := 0; i < runners; i++ { err = multierror.Append(err, errors.Wrap(<-errCh, "error returned by runner")) diff --git a/cmd/notification-service/di/wire.go b/cmd/notification-service/di/wire.go index 6129434..f1ca572 100644 --- a/cmd/notification-service/di/wire.go +++ b/cmd/notification-service/di/wire.go @@ -19,6 +19,7 @@ func BuildService(context.Context, config.Config) (Service, func(), error) { portsSet, applicationSet, firestoreAdaptersSet, + downloaderSet, ) return Service{}, nil, nil } diff --git a/cmd/notification-service/di/wire_gen.go b/cmd/notification-service/di/wire_gen.go index 438492e..911cafe 100644 --- a/cmd/notification-service/di/wire_gen.go +++ b/cmd/notification-service/di/wire_gen.go @@ -40,7 +40,8 @@ func BuildService(contextContext context.Context, configConfig config.Config) (S Queries: queries, } server := http.NewServer(configConfig, application) - service := NewService(application, server) + downloader := app.NewDownloader(transactionProvider) + service := NewService(application, server, downloader) return service, func() { }, nil } diff --git a/cmd/notification-service/main.go b/cmd/notification-service/main.go index dd917bb..d7cafe8 100644 --- a/cmd/notification-service/main.go +++ b/cmd/notification-service/main.go @@ -1,9 +1,18 @@ package main import ( - "errors" + "context" "fmt" "os" + + "github.com/boreq/errors" + "github.com/nbd-wtf/go-nostr" + "github.com/nbd-wtf/go-nostr/nip19" + "github.com/planetary-social/go-notification-service/cmd/notification-service/di" + "github.com/planetary-social/go-notification-service/internal/fixtures" + "github.com/planetary-social/go-notification-service/service/app" + "github.com/planetary-social/go-notification-service/service/config" + "github.com/planetary-social/go-notification-service/service/domain" ) func main() { @@ -11,9 +20,94 @@ func main() { fmt.Printf("error: %s", err) os.Exit(1) } - } func run() error { - return errors.New("not implemented") + ctx := context.Background() + cfg, err := config.NewConfig("", "test-project-id") + if err != nil { + return errors.Wrap(err, "error creating a config") + } + + service, cleanup, err := di.BuildService(ctx, cfg) + if err != nil { + return errors.Wrap(err, "error building a service") + } + defer cleanup() + + addMyRegistration(ctx, service) // todo remove + + return service.Run(ctx) + +} + +func addMyRegistration(ctx context.Context, service di.Service) { + nsec := os.Getenv("NSEC") + _, value, err := nip19.Decode(nsec) + if err != nil { + panic(err) + } + + secretKey := value.(string) + publicKeyString, err := nostr.GetPublicKey(secretKey) + if err != nil { + panic(err) + } + + publicKey, err := domain.NewPublicKey(publicKeyString) + if err != nil { + panic(err) + } + + relayAddress, err := domain.NewRelayAddress("wss://relay.damus.io") + if err != nil { + panic(err) + } + + libEvent := nostr.Event{ + CreatedAt: nostr.Now(), + Kind: 12345, + Tags: nostr.Tags{}, + Content: fmt.Sprintf(` +{ + "publicKeys": [ + { + "publicKey": "%s", + "relays": [ + { + "address": "%s" + } + ] + } + ], + "locale": "%s", + "apnsToken": "%s" +} +`, + publicKey.Hex(), + relayAddress.String(), + fixtures.SomeString(), + fixtures.SomeString()), + } + + err = libEvent.Sign(secretKey) + if err != nil { + panic(err) + } + + event, err := domain.NewEvent(libEvent) + if err != nil { + panic(err) + } + + registration, err := domain.NewRegistrationFromEvent(event) + if err != nil { + panic(err) + } + + cmd := app.NewSaveRegistration(registration) + err = service.App().Commands.SaveRegistration.Handle(ctx, cmd) + if err != nil { + panic(err) + } } diff --git a/internal/set.go b/internal/set.go new file mode 100644 index 0000000..4dd8fde --- /dev/null +++ b/internal/set.go @@ -0,0 +1,48 @@ +package internal + +type Set[T comparable] struct { + values map[T]struct{} +} + +func NewEmptySet[T comparable]() *Set[T] { + return &Set[T]{ + values: make(map[T]struct{}), + } +} + +func NewSet[T comparable](values []T) *Set[T] { + v := NewEmptySet[T]() + for _, value := range values { + v.Put(value) + } + return v +} + +func (s *Set[T]) Contains(v T) bool { + _, ok := s.values[v] + return ok +} + +func (s *Set[T]) Put(v T) { + s.values[v] = struct{}{} +} + +func (s *Set[T]) Clear() { + s.values = make(map[T]struct{}) +} + +func (s *Set[T]) Delete(v T) { + delete(s.values, v) +} + +func (s *Set[T]) List() []T { + var result []T + for v := range s.values { + result = append(result, v) + } + return result +} + +func (s *Set[T]) Len() int { + return len(s.values) +} diff --git a/internal/set_test.go b/internal/set_test.go new file mode 100644 index 0000000..1024829 --- /dev/null +++ b/internal/set_test.go @@ -0,0 +1,49 @@ +package internal + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSet(t *testing.T) { + s := NewEmptySet[int]() + + require.False(t, s.Contains(0)) + require.Equal(t, 0, s.Len()) + require.Len(t, s.List(), 0) + + s.Put(0) + + require.True(t, s.Contains(0)) + require.Equal(t, 1, s.Len()) + require.Equal(t, []int{0}, s.List()) +} + +func TestSet_Delete(t *testing.T) { + s := NewEmptySet[int]() + + require.False(t, s.Contains(0)) + + s.Put(0) + + require.True(t, s.Contains(0)) + + s.Delete(0) + + require.False(t, s.Contains(0)) +} + +func TestSet_Clear(t *testing.T) { + s := NewEmptySet[int]() + + require.False(t, s.Contains(0)) + + s.Put(0) + + require.True(t, s.Contains(0)) + + s.Clear() + + require.False(t, s.Contains(0)) +} diff --git a/service/adapters/firestore/repository_event.go b/service/adapters/firestore/repository_event.go index 328fef7..068bf25 100644 --- a/service/adapters/firestore/repository_event.go +++ b/service/adapters/firestore/repository_event.go @@ -17,6 +17,6 @@ func NewEventRepository(client *firestore.Client, tx *firestore.Transaction) *Ev } func (e EventRepository) Save(relay domain.RelayAddress, event domain.Event) error { - fmt.Println("saving", relay, event) + fmt.Println("saving", relay, string(event.Content())) return nil } diff --git a/service/adapters/firestore/repository_registration.go b/service/adapters/firestore/repository_registration.go index 07bacd5..4ea033a 100644 --- a/service/adapters/firestore/repository_registration.go +++ b/service/adapters/firestore/repository_registration.go @@ -2,6 +2,7 @@ package firestore import ( "context" + "encoding/hex" "cloud.google.com/go/firestore" "github.com/boreq/errors" @@ -61,7 +62,7 @@ func (r *RegistrationRepository) saveUnderTokens(registration domain.Registratio } for _, relayAddress := range pubKeyWithRelays.Relays() { - relayDocPath := publicKeyDocPath.Collection(collectionAPNSTokensPublicKeysRelays).Doc(relayAddress.String()) + relayDocPath := publicKeyDocPath.Collection(collectionAPNSTokensPublicKeysRelays).Doc(r.relayAddressAsKey(relayAddress)) relayDocData := map[string]any{ "address": relayAddress.String(), } @@ -77,7 +78,7 @@ func (r *RegistrationRepository) saveUnderTokens(registration domain.Registratio func (r *RegistrationRepository) saveUnderRelays(registration domain.Registration) error { for _, pubKeyWithRelays := range registration.PublicKeys() { for _, relayAddress := range pubKeyWithRelays.Relays() { - relayDocPath := r.client.Collection(collectionRelays).Doc(relayAddress.String()) + relayDocPath := r.client.Collection(collectionRelays).Doc(r.relayAddressAsKey(relayAddress)) relayDocData := map[string]any{ "address": relayAddress, } @@ -110,9 +111,9 @@ func (r *RegistrationRepository) GetRelays(ctx context.Context) ([]domain.RelayA return nil, errors.Wrap(err, "error calling iter next") } - relayAddress, err := domain.NewRelayAddress(docRef.Ref.ID) + relayAddress, err := r.relayAddressFromKey(docRef.Ref.ID) if err != nil { - return nil, errors.Wrap(err, "error creating a relay address") + return nil, errors.Wrapf(err, "error creating a relay address from key '%s'", docRef.Ref.ID) } result = append(result, relayAddress) } @@ -121,7 +122,7 @@ func (r *RegistrationRepository) GetRelays(ctx context.Context) ([]domain.RelayA } func (r *RegistrationRepository) GetPublicKeys(ctx context.Context, address domain.RelayAddress) ([]domain.PublicKey, error) { - iter := r.client.Collection(collectionRelays).Doc(address.String()).Collection(collectionRelaysPublicKeys).Documents(ctx) + iter := r.client.Collection(collectionRelays).Doc(r.relayAddressAsKey(address)).Collection(collectionRelaysPublicKeys).Documents(ctx) var result []domain.PublicKey for { @@ -142,3 +143,21 @@ func (r *RegistrationRepository) GetPublicKeys(ctx context.Context, address doma return result, nil } + +func (r *RegistrationRepository) relayAddressAsKey(v domain.RelayAddress) string { + return hex.EncodeToString([]byte(v.String())) +} + +func (r *RegistrationRepository) relayAddressFromKey(v string) (domain.RelayAddress, error) { + b, err := hex.DecodeString(v) + if err != nil { + return domain.RelayAddress{}, errors.Wrap(err, "error decoding relay address from hex") + } + + addr, err := domain.NewRelayAddress(string(b)) + if err != nil { + return domain.RelayAddress{}, errors.Wrap(err, "error creating a relay address") + } + + return addr, nil +} diff --git a/service/app/downloader.go b/service/app/downloader.go new file mode 100644 index 0000000..2d84138 --- /dev/null +++ b/service/app/downloader.go @@ -0,0 +1,268 @@ +package app + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/boreq/errors" + "github.com/gorilla/websocket" + "github.com/nbd-wtf/go-nostr" + "github.com/planetary-social/go-notification-service/internal" + "github.com/planetary-social/go-notification-service/service/domain" +) + +type Downloader struct { + transactionProvider TransactionProvider + relayDownloaders map[domain.RelayAddress]*RelayDownloader +} + +func NewDownloader(transaction TransactionProvider) *Downloader { + return &Downloader{ + transactionProvider: transaction, + relayDownloaders: map[domain.RelayAddress]*RelayDownloader{}, + } +} + +func (d *Downloader) Run(ctx context.Context) error { + for { + relayAddresses, err := d.getRelays(ctx) + if err != nil { + return errors.Wrap(err, "error getting relays") + } + + for relayAddress, relayDownloader := range d.relayDownloaders { + if !relayAddresses.Contains(relayAddress) { + delete(d.relayDownloaders, relayAddress) + relayDownloader.Stop() + } + } + + for _, relayAddress := range relayAddresses.List() { + if _, ok := d.relayDownloaders[relayAddress]; !ok { + relayDownloader := NewRelayDownloader(ctx, d.transactionProvider, relayAddress) + d.relayDownloaders[relayAddress] = relayDownloader + } + } + + <-time.After(60 * time.Second) + } +} + +func (d *Downloader) getRelays(ctx context.Context) (*internal.Set[domain.RelayAddress], error) { + var relays []domain.RelayAddress + + if err := d.transactionProvider.Transact(ctx, func(ctx context.Context, adapters Adapters) error { + tmp, err := adapters.Registrations.GetRelays(ctx) + if err != nil { + return errors.Wrap(err, "error getting relays") + } + relays = tmp + return nil + }); err != nil { + return nil, errors.Wrap(err, "transaction error") + } + + return internal.NewSet(relays), nil +} + +type RelayDownloader struct { + address domain.RelayAddress + transactionProvider TransactionProvider + cancel context.CancelFunc +} + +func NewRelayDownloader(ctx context.Context, transactionProvider TransactionProvider, address domain.RelayAddress) *RelayDownloader { + ctx, cancel := context.WithCancel(ctx) + v := &RelayDownloader{ + transactionProvider: transactionProvider, + cancel: cancel, + address: address, + } + go v.run(ctx) + return v +} + +func (d *RelayDownloader) run(ctx context.Context) { + for { + if err := d.connectAndDownload(ctx); err != nil { + fmt.Printf("error processing relay '%s': %s\n", d.address, err) + } + + select { + case <-ctx.Done(): + return + case <-time.After(10 * time.Second): + continue + } + } +} + +func (d *RelayDownloader) connectAndDownload(ctx context.Context) error { + conn, _, err := websocket.DefaultDialer.DialContext(ctx, d.address.String(), nil) + if err != nil { + return errors.Wrap(err, "error dialing the relay") + } + defer conn.Close() + + activeSubscriptions := internal.NewEmptySet[domain.PublicKey]() + activeSubscriptionsLock := &sync.Mutex{} + + go func() { + if err := d.manageSubs(ctx, conn, activeSubscriptions, activeSubscriptionsLock); err != nil { + fmt.Println("error managing subs", err) + } + }() + + for { + _, messageBytes, err := conn.ReadMessage() + if err != nil { + return errors.Wrap(err, "error reading a message") + } + + if err := d.handleMessage(ctx, messageBytes, activeSubscriptions, activeSubscriptionsLock); err != nil { + return errors.Wrap(err, "error handling message") + } + + } + +} + +func (d *RelayDownloader) handleMessage( + ctx context.Context, + messageBytes []byte, + activeSubscriptions *internal.Set[domain.PublicKey], + activeSubscriptionsLock *sync.Mutex, +) error { + envelope := nostr.ParseMessage(messageBytes) + if envelope == nil { + return errors.New("error parsing message, we are never going to find out what error unfortunately due to the design of this library") + } + + switch v := envelope.(type) { + case *nostr.EOSEEnvelope: + publicKey, err := domain.NewPublicKey(string(*v)) + if err != nil { + return errors.Wrap(err, "invalid public key; unexpected subscription id since we only create them from public keys") + } + + activeSubscriptionsLock.Lock() + activeSubscriptionsLock.Unlock() + activeSubscriptions.Delete(publicKey) + // todo there is a bug here, we may have recreated the sub and this + // message refers to the previous sub + case *nostr.EventEnvelope: + event, err := domain.NewEvent(v.Event) + if err != nil { + return errors.Wrap(err, "error creating an event") + } + + // todo maybe pubsub those events and then handle them later? + if err := d.transactionProvider.Transact(ctx, func(ctx context.Context, adapters Adapters) error { + // todo figure out if we actually want to save this + return adapters.Events.Save(d.address, event) + }); err != nil { + return errors.Wrap(err, "transaction error") + } + default: + fmt.Println("unknown message:", string(messageBytes)) + } + + return nil +} + +func (d *RelayDownloader) manageSubs( + ctx context.Context, + conn *websocket.Conn, + activeSubscriptions *internal.Set[domain.PublicKey], + activeSubscriptionsLock *sync.Mutex, +) error { + defer conn.Close() + + for { + publicKeys, err := d.getPublicKeys(ctx) + if err != nil { + return errors.Wrap(err, "error getting public keys") + } + + if err := d.updateSubs(conn, activeSubscriptions, activeSubscriptionsLock, publicKeys); err != nil { + return errors.Wrap(err, "error updating subscriptions") + } + } +} + +func (d *RelayDownloader) updateSubs( + conn *websocket.Conn, + activeSubscriptions *internal.Set[domain.PublicKey], + activeSubscriptionsLock *sync.Mutex, + publicKeys *internal.Set[domain.PublicKey], +) error { + activeSubscriptionsLock.Lock() + defer activeSubscriptionsLock.Unlock() + + for _, publicKey := range activeSubscriptions.List() { + if !publicKeys.Contains(publicKey) { + envelope := nostr.CloseEnvelope(publicKey.Hex()) + + envelopeJSON, err := envelope.MarshalJSON() + if err != nil { + return errors.Wrap(err, "marshaling close envelope failed") + } + + if err := conn.WriteMessage(websocket.TextMessage, envelopeJSON); err != nil { + return errors.Wrap(err, "writing close envelope error") + } + + activeSubscriptions.Delete(publicKey) + } + } + + for _, publicKey := range publicKeys.List() { + if ok := activeSubscriptions.Contains(publicKey); !ok { + envelope := nostr.ReqEnvelope{ + SubscriptionID: publicKey.Hex(), + Filters: nostr.Filters{nostr.Filter{ + Authors: []string{ + publicKey.Hex(), + }, + Since: nil, // todo filter based on already received events + }}, + } + + envelopeJSON, err := envelope.MarshalJSON() + if err != nil { + return errors.Wrap(err, "marshaling req envelope failed") + } + + if err := conn.WriteMessage(websocket.TextMessage, envelopeJSON); err != nil { + return errors.Wrap(err, "writing req envelope error") + } + + activeSubscriptions.Put(publicKey) + } + } + + return nil +} + +func (d *RelayDownloader) getPublicKeys(ctx context.Context) (*internal.Set[domain.PublicKey], error) { + var publicKeys []domain.PublicKey + + if err := d.transactionProvider.Transact(ctx, func(ctx context.Context, adapters Adapters) error { + tmp, err := adapters.Registrations.GetPublicKeys(ctx, d.address) + if err != nil { + return errors.Wrap(err, "error getting public keys") + } + publicKeys = tmp + return nil + }); err != nil { + return nil, errors.Wrap(err, "transaction error") + } + + return internal.NewSet(publicKeys), nil +} + +func (d RelayDownloader) Stop() { + d.cancel() +} diff --git a/service/app/handler_get_public_keys.go b/service/app/handler_get_public_keys.go new file mode 100644 index 0000000..3c9ab23 --- /dev/null +++ b/service/app/handler_get_public_keys.go @@ -0,0 +1,35 @@ +package app + +import ( + "context" + + "github.com/boreq/errors" + "github.com/planetary-social/go-notification-service/service/domain" +) + +type GetPublicKeysHandler struct { + transactionProvider TransactionProvider +} + +func NewGetPublicKeysHandler( + transactionProvider TransactionProvider, +) *GetPublicKeysHandler { + return &GetPublicKeysHandler{ + transactionProvider: transactionProvider, + } +} + +func (h *GetPublicKeysHandler) Handle(ctx context.Context, relay domain.RelayAddress) ([]domain.PublicKey, error) { + var result []domain.PublicKey + if err := h.transactionProvider.Transact(ctx, func(ctx context.Context, adapters Adapters) error { + tmp, err := adapters.Registrations.GetPublicKeys(ctx, relay) + if err != nil { + return errors.Wrap(err, "error getting relays") + } + result = tmp + return nil + }); err != nil { + return nil, errors.Wrap(err, "transaction error") + } + return result, nil +} diff --git a/service/app/handler_get_relays.go b/service/app/handler_get_relays.go new file mode 100644 index 0000000..dd8455d --- /dev/null +++ b/service/app/handler_get_relays.go @@ -0,0 +1,35 @@ +package app + +import ( + "context" + + "github.com/boreq/errors" + "github.com/planetary-social/go-notification-service/service/domain" +) + +type GetRelaysHandler struct { + transactionProvider TransactionProvider +} + +func NewGetRelaysHandler( + transactionProvider TransactionProvider, +) *GetRelaysHandler { + return &GetRelaysHandler{ + transactionProvider: transactionProvider, + } +} + +func (h *GetRelaysHandler) Handle(ctx context.Context) ([]domain.RelayAddress, error) { + var result []domain.RelayAddress + if err := h.transactionProvider.Transact(ctx, func(ctx context.Context, adapters Adapters) error { + tmp, err := adapters.Registrations.GetRelays(ctx) + if err != nil { + return errors.Wrap(err, "error getting relays") + } + result = tmp + return nil + }); err != nil { + return nil, errors.Wrap(err, "transaction error") + } + return result, nil +} diff --git a/service/domain/event.go b/service/domain/event.go index aaa40b1..51f7de8 100644 --- a/service/domain/event.go +++ b/service/domain/event.go @@ -12,8 +12,8 @@ type Event struct { content []byte } -func NewEventFromEnvelope(envelope nostr.EventEnvelope) (Event, error) { - ok, err := envelope.CheckSignature() +func NewEvent(libevent nostr.Event) (Event, error) { + ok, err := libevent.CheckSignature() if err != nil { return Event{}, errors.Wrap(err, "error checking signature") } @@ -22,14 +22,14 @@ func NewEventFromEnvelope(envelope nostr.EventEnvelope) (Event, error) { return Event{}, errors.New("invalid signature") } - pubKey, err := NewPublicKey(envelope.PubKey) + pubKey, err := NewPublicKey(libevent.PubKey) if err != nil { return Event{}, errors.Wrap(err, "error creating a pub key") } return Event{ pubKey: pubKey, - content: []byte(envelope.Content), + content: []byte(libevent.Content), }, nil } diff --git a/service/domain/public_key_test.go b/service/domain/public_key_test.go new file mode 100644 index 0000000..dd6709d --- /dev/null +++ b/service/domain/public_key_test.go @@ -0,0 +1,17 @@ +package domain + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPublicKey_IsCaseInsensitive(t *testing.T) { + a, err := NewPublicKey("ABCD") + require.NoError(t, err) + + b, err := NewPublicKey("abcd") + require.NoError(t, err) + + require.Equal(t, a, b) +} diff --git a/service/domain/registration.go b/service/domain/registration.go index 99f060f..91c7f7d 100644 --- a/service/domain/registration.go +++ b/service/domain/registration.go @@ -2,6 +2,7 @@ package domain import ( "encoding/json" + "strings" "github.com/boreq/errors" "github.com/planetary-social/go-notification-service/internal" @@ -111,8 +112,11 @@ type RelayAddress struct { } func NewRelayAddress(s string) (RelayAddress, error) { - // todo validate + if !strings.HasPrefix(s, "ws://") && !strings.HasPrefix(s, "wss://") { + return RelayAddress{}, errors.New("invalid protocol") + } + // todo validate return RelayAddress{s: s}, nil } diff --git a/service/ports/http/http.go b/service/ports/http/http.go index 2c1d2f1..faaf844 100644 --- a/service/ports/http/http.go +++ b/service/ports/http/http.go @@ -92,7 +92,7 @@ func (s *Server) handleConnection(ctx context.Context, conn *websocket.Conn) err switch v := message.(type) { case *nostr.EventEnvelope: - event, err := domain.NewEventFromEnvelope(*v) + event, err := domain.NewEvent(v.Event) if err != nil { return errors.Wrap(err, "error creating an event") }