diff --git a/internal/broker/service.go b/internal/broker/service.go index 8f1ad6c7..92d6004b 100644 --- a/internal/broker/service.go +++ b/internal/broker/service.go @@ -157,11 +157,7 @@ func NewService(ctx context.Context, cfg *config.Config) (s *Service, err error) s.surveyor = survey.New(s.pubsub, s.cluster) s.presence = presence.New(s, s.pubsub, s.surveyor, s.subscriptions) if s.cluster != nil { - if s.storage.Name() == ssdstore.Name() { - s.surveyor.HandleFunc(s.presence, ssdstore) - } else if s.storage.Name() == memstore.Name() { - s.surveyor.HandleFunc(s.presence, memstore) - } + s.surveyor.HandleFunc(s.storage) } // Create a new cipher from the licence provided diff --git a/internal/provider/storage/memory_test.go b/internal/provider/storage/memory_test.go index 062e1add..a390cfde 100644 --- a/internal/provider/storage/memory_test.go +++ b/internal/provider/storage/memory_test.go @@ -117,7 +117,7 @@ func TestInMemory_Query(t *testing.T) { if tc.gathered == nil { s.survey = nil } else { - s.survey = survey(func(string, []byte) (message.Awaiter, error) { + s.survey = surveyFunc(func(string, []byte) (message.Awaiter, error) { return &mockAwaiter{f: func(_ time.Duration) [][]byte { return [][]byte{tc.gathered} }}, nil }) } diff --git a/internal/provider/storage/ssd_test.go b/internal/provider/storage/ssd_test.go index 8a42dc7f..a7b61ebd 100644 --- a/internal/provider/storage/ssd_test.go +++ b/internal/provider/storage/ssd_test.go @@ -128,7 +128,7 @@ func TestSSD_QuerySurveyed(t *testing.T) { if tc.gathered == nil { s.survey = nil } else { - s.survey = survey(func(string, []byte) (message.Awaiter, error) { + s.survey = surveyFunc(func(string, []byte) (message.Awaiter, error) { return &mockAwaiter{f: func(_ time.Duration) [][]byte { return [][]byte{tc.gathered} }}, nil }) } diff --git a/internal/provider/storage/storage.go b/internal/provider/storage/storage.go index e3689195..ac76d0a7 100644 --- a/internal/provider/storage/storage.go +++ b/internal/provider/storage/storage.go @@ -22,6 +22,7 @@ import ( "github.com/emitter-io/config" "github.com/emitter-io/emitter/internal/message" "github.com/emitter-io/emitter/internal/security" + "github.com/emitter-io/emitter/internal/service/survey" ) var ( @@ -37,6 +38,7 @@ const ( type Storage interface { config.Provider io.Closer + survey.Surveyee // Store is used to store a message, the SSID provided must be a full SSID // SSID, where first element should be a contract ID. The time resolution @@ -139,3 +141,8 @@ func (s *Noop) Query(ssid message.Ssid, from, until time.Time, startFromID messa func (s *Noop) Close() error { return nil } + +// OnSurvey handles an incoming cluster lookup request. +func (s *Noop) OnSurvey(surveyType string, payload []byte) ([]byte, bool) { + return []byte{}, true +} diff --git a/internal/provider/storage/storage_test.go b/internal/provider/storage/storage_test.go index 730979d5..d59c42ba 100644 --- a/internal/provider/storage/storage_test.go +++ b/internal/provider/storage/storage_test.go @@ -24,9 +24,9 @@ import ( "github.com/stretchr/testify/assert" ) -type survey func(string, []byte) (message.Awaiter, error) +type surveyFunc func(string, []byte) (message.Awaiter, error) -func (s survey) Query(q string, b []byte) (message.Awaiter, error) { +func (s surveyFunc) Query(q string, b []byte) (message.Awaiter, error) { return s(q, b) } diff --git a/internal/service/pubsub/subscribe_test.go b/internal/service/pubsub/subscribe_test.go index 06d394ad..3aeec2a3 100644 --- a/internal/service/pubsub/subscribe_test.go +++ b/internal/service/pubsub/subscribe_test.go @@ -184,3 +184,7 @@ func (s *buggyStore) Query(ssid message.Ssid, from, until time.Time, startFromID func (s *buggyStore) Close() error { return errors.New("not working") } + +func (s *buggyStore) OnSurvey(surveyType string, payload []byte) ([]byte, bool) { + return []byte{}, true +}