diff --git a/bolt_transport_test.go b/bolt_transport_test.go index ac92877a..0e31a364 100644 --- a/bolt_transport_test.go +++ b/bolt_transport_test.go @@ -35,7 +35,7 @@ func TestBoltTransportHistory(t *testing.T) { }) } - s := NewSubscriber("8", transport.logger) + s := NewSubscriber("8", transport.logger, &TopicSelectorStore{}) s.SetTopics(topics, nil) require.NoError(t, transport.AddSubscriber(s)) @@ -68,7 +68,7 @@ func TestBoltTransportLogsBogusLastEventID(t *testing.T) { Topics: topics, }) - s := NewSubscriber("711131", logger) + s := NewSubscriber("711131", logger, &TopicSelectorStore{}) s.SetTopics(topics, nil) require.NoError(t, transport.AddSubscriber(s)) @@ -87,7 +87,7 @@ func TestBoltTopicSelectorHistory(t *testing.T) { transport.Dispatch(&Update{Topics: []string{"http://example.com/subscribed-public-only"}, Private: true, Event: Event{ID: "3"}}) transport.Dispatch(&Update{Topics: []string{"http://example.com/subscribed-public-only"}, Event: Event{ID: "4"}}) - s := NewSubscriber(EarliestLastEventID, transport.logger) + s := NewSubscriber(EarliestLastEventID, transport.logger, &TopicSelectorStore{}) s.SetTopics([]string{"http://example.com/subscribed", "http://example.com/subscribed-public-only"}, []string{"http://example.com/subscribed"}) require.NoError(t, transport.AddSubscriber(s)) @@ -109,7 +109,7 @@ func TestBoltTransportRetrieveAllHistory(t *testing.T) { }) } - s := NewSubscriber(EarliestLastEventID, transport.logger) + s := NewSubscriber(EarliestLastEventID, transport.logger, &TopicSelectorStore{}) s.SetTopics(topics, nil) require.NoError(t, transport.AddSubscriber(s)) @@ -139,7 +139,7 @@ func TestBoltTransportHistoryAndLive(t *testing.T) { }) } - s := NewSubscriber("8", transport.logger) + s := NewSubscriber("8", transport.logger, &TopicSelectorStore{}) s.SetTopics(topics, nil) require.NoError(t, transport.AddSubscriber(s)) @@ -221,7 +221,7 @@ func TestBoltTransportDoNotDispatchUntilListen(t *testing.T) { defer os.Remove("test.db") assert.Implements(t, (*Transport)(nil), transport) - s := NewSubscriber("", transport.logger) + s := NewSubscriber("", transport.logger, &TopicSelectorStore{}) require.NoError(t, transport.AddSubscriber(s)) var wg sync.WaitGroup @@ -245,7 +245,7 @@ func TestBoltTransportDispatch(t *testing.T) { defer os.Remove("test.db") assert.Implements(t, (*Transport)(nil), transport) - s := NewSubscriber("", transport.logger) + s := NewSubscriber("", transport.logger, &TopicSelectorStore{}) s.SetTopics([]string{"https://example.com/foo", "https://example.com/private"}, []string{"https://example.com/private"}) require.NoError(t, transport.AddSubscriber(s)) @@ -274,7 +274,7 @@ func TestBoltTransportClosed(t *testing.T) { defer os.Remove("test.db") assert.Implements(t, (*Transport)(nil), transport) - s := NewSubscriber("", transport.logger) + s := NewSubscriber("", transport.logger, &TopicSelectorStore{}) s.SetTopics([]string{"https://example.com/foo"}, nil) require.NoError(t, transport.AddSubscriber(s)) @@ -293,11 +293,11 @@ func TestBoltCleanDisconnectedSubscribers(t *testing.T) { defer transport.Close() defer os.Remove("test.db") - s1 := NewSubscriber("", transport.logger) + s1 := NewSubscriber("", transport.logger, &TopicSelectorStore{}) s1.SetTopics([]string{"foo"}, []string{}) require.NoError(t, transport.AddSubscriber(s1)) - s2 := NewSubscriber("", transport.logger) + s2 := NewSubscriber("", transport.logger, &TopicSelectorStore{}) s2.SetTopics([]string{"foo"}, []string{}) require.NoError(t, transport.AddSubscriber(s2)) @@ -318,10 +318,10 @@ func TestBoltGetSubscribers(t *testing.T) { defer transport.Close() defer os.Remove("test.db") - s1 := NewSubscriber("", transport.logger) + s1 := NewSubscriber("", transport.logger, &TopicSelectorStore{}) require.NoError(t, transport.AddSubscriber(s1)) - s2 := NewSubscriber("", transport.logger) + s2 := NewSubscriber("", transport.logger, &TopicSelectorStore{}) require.NoError(t, transport.AddSubscriber(s2)) lastEventID, subscribers, err := transport.GetSubscribers() diff --git a/local_transport_bench_test.go b/local_transport_bench_test.go index e38cd699..78984356 100644 --- a/local_transport_bench_test.go +++ b/local_transport_bench_test.go @@ -39,8 +39,9 @@ func subBenchLocalTransport(b *testing.B, topics, concurrency, matchPct int, tes } } out := make(chan *Update, 50000) + tss := &TopicSelectorStore{} for i := 0; i < concurrency; i++ { - s := NewSubscriber("", zap.NewNop()) + s := NewSubscriber("", zap.NewNop(), tss) if i%100 < matchPct { s.SetTopics(tsMatch, nil) } else { diff --git a/local_transport_test.go b/local_transport_test.go index 14ba3721..964c5ec7 100644 --- a/local_transport_test.go +++ b/local_transport_test.go @@ -11,7 +11,8 @@ import ( ) func TestLocalTransportDoNotDispatchUntilListen(t *testing.T) { - transport, _ := DeprecatedNewLocalTransport(&url.URL{Scheme: "local"}, zap.NewNop()) + logger := zap.NewNop() + transport, _ := DeprecatedNewLocalTransport(&url.URL{Scheme: "local"}, logger) defer transport.Close() assert.Implements(t, (*Transport)(nil), transport) @@ -19,7 +20,7 @@ func TestLocalTransportDoNotDispatchUntilListen(t *testing.T) { err := transport.Dispatch(u) require.NoError(t, err) - s := NewSubscriber("", zap.NewNop()) + s := NewSubscriber("", logger, &TopicSelectorStore{}) s.SetTopics(u.Topics, nil) require.NoError(t, transport.AddSubscriber(s)) @@ -37,11 +38,12 @@ func TestLocalTransportDoNotDispatchUntilListen(t *testing.T) { } func TestLocalTransportDispatch(t *testing.T) { - transport, _ := DeprecatedNewLocalTransport(&url.URL{Scheme: "local"}, zap.NewNop()) + logger := zap.NewNop() + transport, _ := DeprecatedNewLocalTransport(&url.URL{Scheme: "local"}, logger) defer transport.Close() assert.Implements(t, (*Transport)(nil), transport) - s := NewSubscriber("", zap.NewNop()) + s := NewSubscriber("", logger, &TopicSelectorStore{}) s.SetTopics([]string{"http://example.com/foo"}, nil) require.NoError(t, transport.AddSubscriber(s)) @@ -51,14 +53,17 @@ func TestLocalTransportDispatch(t *testing.T) { } func TestLocalTransportClosed(t *testing.T) { - transport, _ := DeprecatedNewLocalTransport(&url.URL{Scheme: "local"}, zap.NewNop()) + logger := zap.NewNop() + transport, _ := DeprecatedNewLocalTransport(&url.URL{Scheme: "local"}, logger) defer transport.Close() assert.Implements(t, (*Transport)(nil), transport) - s := NewSubscriber("", zap.NewNop()) + tss := &TopicSelectorStore{} + + s := NewSubscriber("", logger, tss) require.NoError(t, transport.AddSubscriber(s)) require.NoError(t, transport.Close()) - assert.Equal(t, transport.AddSubscriber(NewSubscriber("", zap.NewNop())), ErrClosedTransport) + assert.Equal(t, transport.AddSubscriber(NewSubscriber("", logger, tss)), ErrClosedTransport) assert.Equal(t, transport.Dispatch(&Update{}), ErrClosedTransport) _, ok := <-s.out @@ -66,14 +71,17 @@ func TestLocalTransportClosed(t *testing.T) { } func TestLiveCleanDisconnectedSubscribers(t *testing.T) { - tr, _ := DeprecatedNewLocalTransport(&url.URL{Scheme: "local"}, zap.NewNop()) + logger := zap.NewNop() + tr, _ := DeprecatedNewLocalTransport(&url.URL{Scheme: "local"}, logger) transport := tr.(*LocalTransport) defer transport.Close() - s1 := NewSubscriber("", zap.NewNop()) + tss := &TopicSelectorStore{} + + s1 := NewSubscriber("", logger, tss) require.NoError(t, transport.AddSubscriber(s1)) - s2 := NewSubscriber("", zap.NewNop()) + s2 := NewSubscriber("", logger, tss) require.NoError(t, transport.AddSubscriber(s2)) assert.Equal(t, 2, transport.subscribers.Len()) @@ -88,11 +96,12 @@ func TestLiveCleanDisconnectedSubscribers(t *testing.T) { } func TestLiveReading(t *testing.T) { - transport, _ := DeprecatedNewLocalTransport(&url.URL{Scheme: "local"}, zap.NewNop()) + logger := zap.NewNop() + transport, _ := DeprecatedNewLocalTransport(&url.URL{Scheme: "local"}, logger) defer transport.Close() assert.Implements(t, (*Transport)(nil), transport) - s := NewSubscriber("", zap.NewNop()) + s := NewSubscriber("", logger, &TopicSelectorStore{}) s.SetTopics([]string{"https://example.com"}, nil) require.NoError(t, transport.AddSubscriber(s)) @@ -104,14 +113,17 @@ func TestLiveReading(t *testing.T) { } func TestLocalTransportGetSubscribers(t *testing.T) { - transport, _ := DeprecatedNewLocalTransport(&url.URL{Scheme: "local"}, zap.NewNop()) + logger := zap.NewNop() + transport, _ := DeprecatedNewLocalTransport(&url.URL{Scheme: "local"}, logger) defer transport.Close() require.NotNil(t, transport) - s1 := NewSubscriber("", zap.NewNop()) + tss := &TopicSelectorStore{} + + s1 := NewSubscriber("", logger, tss) require.NoError(t, transport.AddSubscriber(s1)) - s2 := NewSubscriber("", zap.NewNop()) + s2 := NewSubscriber("", logger, tss) require.NoError(t, transport.AddSubscriber(s2)) lastEventID, subscribers, err := transport.(TransportSubscribers).GetSubscribers() diff --git a/metrics_test.go b/metrics_test.go index ca183b34..88e82b8c 100644 --- a/metrics_test.go +++ b/metrics_test.go @@ -12,12 +12,15 @@ import ( func TestNumberOfRunningSubscribers(t *testing.T) { m := NewPrometheusMetrics(nil) - s1 := NewSubscriber("", zap.NewNop()) + logger := zap.NewNop() + tss := &TopicSelectorStore{} + + s1 := NewSubscriber("", logger, tss) s1.SetTopics([]string{"topic1", "topic2"}, nil) m.SubscriberConnected(s1) assertGaugeValue(t, 1.0, m.subscribers) - s2 := NewSubscriber("", zap.NewNop()) + s2 := NewSubscriber("", logger, tss) s2.SetTopics([]string{"topic2"}, nil) m.SubscriberConnected(s2) assertGaugeValue(t, 2.0, m.subscribers) @@ -32,12 +35,15 @@ func TestNumberOfRunningSubscribers(t *testing.T) { func TestTotalNumberOfHandledSubscribers(t *testing.T) { m := NewPrometheusMetrics(nil) - s1 := NewSubscriber("", zap.NewNop()) + logger := zap.NewNop() + tss := &TopicSelectorStore{} + + s1 := NewSubscriber("", logger, tss) s1.SetTopics([]string{"topic1", "topic2"}, nil) m.SubscriberConnected(s1) assertCounterValue(t, 1.0, m.subscribersTotal) - s2 := NewSubscriber("", zap.NewNop()) + s2 := NewSubscriber("", logger, tss) s2.SetTopics([]string{"topic2"}, nil) m.SubscriberConnected(s2) assertCounterValue(t, 2.0, m.subscribersTotal) diff --git a/publish_test.go b/publish_test.go index 38da318f..92dc2d16 100644 --- a/publish_test.go +++ b/publish_test.go @@ -174,7 +174,7 @@ func TestPublishOK(t *testing.T) { hub := createDummy() topics := []string{"http://example.com/books/1"} - s := NewSubscriber("", zap.NewNop()) + s := NewSubscriber("", zap.NewNop(), &TopicSelectorStore{}) s.SetTopics(topics, topics) s.Claims = &claims{Mercure: mercureClaim{Subscribe: topics}} @@ -238,7 +238,7 @@ func TestPublishNoData(t *testing.T) { func TestPublishGenerateUUID(t *testing.T) { h := createDummy() - s := NewSubscriber("", zap.NewNop()) + s := NewSubscriber("", zap.NewNop(), &TopicSelectorStore{}) s.SetTopics([]string{"http://example.com/books/1"}, s.SubscribedTopics) require.NoError(t, h.transport.AddSubscriber(s)) diff --git a/subscribe.go b/subscribe.go index 6ee9dd7e..476e83b9 100644 --- a/subscribe.go +++ b/subscribe.go @@ -156,7 +156,8 @@ func (h *Hub) SubscribeHandler(w http.ResponseWriter, r *http.Request) { // registerSubscriber initializes the connection. func (h *Hub) registerSubscriber(w http.ResponseWriter, r *http.Request) (*Subscriber, *responseController) { - s := NewSubscriber(retrieveLastEventID(r, h.opt, h.logger), h.logger) + s := NewSubscriber(retrieveLastEventID(r, h.opt, h.logger), h.logger, &TopicSelectorStore{}) + s.topicSelectorStore = h.topicSelectorStore s.Debug = h.debug s.RemoteAddr = r.RemoteAddr var privateTopics []string diff --git a/subscribe_test.go b/subscribe_test.go index ba39c2f3..24083513 100644 --- a/subscribe_test.go +++ b/subscribe_test.go @@ -3,6 +3,7 @@ package mercure import ( "context" "errors" + "fmt" "io" "net/http" "net/http/httptest" @@ -56,8 +57,13 @@ func (rt *responseTester) Write(buf []byte) (int, error) { if rt.body == rt.expectedBody { rt.cancel() } else if !strings.HasPrefix(rt.expectedBody, rt.body) { - rt.t.Errorf(`Received body "%s" doesn't match expected body "%s"`, rt.body, rt.expectedBody) - rt.cancel() + defer rt.cancel() + + mess := fmt.Sprintf(`Received body "%s" doesn't match expected body "%s"`, rt.body, rt.expectedBody) + if rt.t == nil { + panic(mess) + } + rt.t.Error(mess) } return len(buf), nil diff --git a/subscriber.go b/subscriber.go index 8edade9f..b22ac03b 100644 --- a/subscriber.go +++ b/subscriber.go @@ -8,7 +8,6 @@ import ( "sync/atomic" "github.com/gofrs/uuid" - uritemplate "github.com/yosida95/uritemplate/v3" "go.uber.org/zap" "go.uber.org/zap/zapcore" ) @@ -35,12 +34,13 @@ type Subscriber struct { ready int32 liveQueue []*Update liveMutex sync.RWMutex + topicSelectorStore *TopicSelectorStore } const outBufferLength = 1000 // NewSubscriber creates a new subscriber. -func NewSubscriber(lastEventID string, logger Logger) *Subscriber { +func NewSubscriber(lastEventID string, logger Logger, topicSelectorStore *TopicSelectorStore) *Subscriber { id := "urn:uuid:" + uuid.Must(uuid.NewV4()).String() s := &Subscriber{ ID: id, @@ -49,6 +49,7 @@ func NewSubscriber(lastEventID string, logger Logger) *Subscriber { responseLastEventID: make(chan string, 1), out: make(chan *Update, outBufferLength), logger: logger, + topicSelectorStore: topicSelectorStore, } return s @@ -142,23 +143,7 @@ func (s *Subscriber) Disconnect() { // SetTopics compiles topic selector regexps. func (s *Subscriber) SetTopics(subscribedTopics, allowedPrivateTopics []string) { s.SubscribedTopics = subscribedTopics - s.SubscribedTopicRegexps = make([]*regexp.Regexp, len(subscribedTopics)) - for i, ts := range subscribedTopics { - var r *regexp.Regexp - if tpl, err := uritemplate.New(ts); err == nil { - r = tpl.Regexp() - } - s.SubscribedTopicRegexps[i] = r - } s.AllowedPrivateTopics = allowedPrivateTopics - s.AllowedPrivateRegexps = make([]*regexp.Regexp, len(allowedPrivateTopics)) - for i, ts := range allowedPrivateTopics { - var r *regexp.Regexp - if tpl, err := uritemplate.New(ts); err == nil { - r = tpl.Regexp() - } - s.AllowedPrivateRegexps[i] = r - } s.EscapedTopics = escapeTopics(subscribedTopics) } @@ -180,15 +165,8 @@ func (s *Subscriber) MatchTopics(topics []string, private bool) bool { for _, topic := range topics { if !subscribed { - for i, ts := range s.SubscribedTopics { - if ts == "*" || ts == topic { - subscribed = true - - break - } - - r := s.SubscribedTopicRegexps[i] - if r != nil && r.MatchString(topic) { + for _, ts := range s.SubscribedTopics { + if s.topicSelectorStore.match(topic, ts) { subscribed = true break @@ -197,28 +175,17 @@ func (s *Subscriber) MatchTopics(topics []string, private bool) bool { } if !canAccess { - for i, ts := range s.AllowedPrivateTopics { - if ts == "*" || ts == topic { - canAccess = true - - break - } - - r := s.AllowedPrivateRegexps[i] - if r != nil && r.MatchString(topic) { + for _, ts := range s.AllowedPrivateTopics { + if s.topicSelectorStore.match(topic, ts) { canAccess = true break } } } - - if subscribed && canAccess { - return true - } } - return false + return subscribed && canAccess } // Match checks if the current subscriber can receive the given update. diff --git a/subscriber_bench_test.go b/subscriber_bench_test.go index 3659a802..18cfd5d7 100644 --- a/subscriber_bench_test.go +++ b/subscriber_bench_test.go @@ -85,7 +85,7 @@ func strInt(s string) int { func subBenchSubscriber(b *testing.B, topics, concurrency, matchPct int, testName string) { b.Helper() - s := NewSubscriber("0e249241-6432-4ce1-b9b9-5d170163c253", zap.NewNop()) + s := NewSubscriber("0e249241-6432-4ce1-b9b9-5d170163c253", zap.NewNop(), &TopicSelectorStore{}) ts := make([]string, topics) tsMatch := make([]string, topics) tsNoMatch := make([]string, topics) diff --git a/subscriber_list_test.go b/subscriber_list_test.go index a52f6c76..a9e69834 100644 --- a/subscriber_list_test.go +++ b/subscriber_list_test.go @@ -22,10 +22,11 @@ func TestDecode(t *testing.T) { func BenchmarkSubscriberList(b *testing.B) { logger := zap.NewNop() + tss := &TopicSelectorStore{} l := NewSubscriberList(100) for i := 0; i < 100; i++ { - s := NewSubscriber("", logger) + s := NewSubscriber("", logger, tss) t := fmt.Sprintf("https://example.com/%d", (i % 10)) s.SetTopics([]string{"https://example.org/foo", t}, []string{"https://example.net/bar", t}) diff --git a/subscriber_test.go b/subscriber_test.go index 1432ada9..7c324b5f 100644 --- a/subscriber_test.go +++ b/subscriber_test.go @@ -9,7 +9,7 @@ import ( ) func TestDispatch(t *testing.T) { - s := NewSubscriber("1", zap.NewNop()) + s := NewSubscriber("1", zap.NewNop(), &TopicSelectorStore{}) s.SubscribedTopics = []string{"http://example.com"} s.SubscribedTopics = []string{"http://example.com"} defer s.Disconnect() @@ -32,7 +32,7 @@ func TestDispatch(t *testing.T) { } func TestDisconnect(t *testing.T) { - s := NewSubscriber("", zap.NewNop()) + s := NewSubscriber("", zap.NewNop(), &TopicSelectorStore{}) s.Disconnect() // can be called two times without crashing s.Disconnect() @@ -44,7 +44,7 @@ func TestLogSubscriber(t *testing.T) { sink, logger := newTestLogger(t) defer sink.Reset() - s := NewSubscriber("123", logger) + s := NewSubscriber("123", logger, &TopicSelectorStore{}) s.RemoteAddr = "127.0.0.1" s.SetTopics([]string{"https://example.com/bar"}, []string{"https://example.com/foo"}) @@ -59,7 +59,7 @@ func TestLogSubscriber(t *testing.T) { } func TestMatchTopic(t *testing.T) { - s := NewSubscriber("", zap.NewNop()) + s := NewSubscriber("", zap.NewNop(), &TopicSelectorStore{}) s.SetTopics([]string{"https://example.com/no-match", "https://example.com/books/{id}"}, []string{"https://example.com/users/foo/{?topic}"}) assert.False(t, s.Match(&Update{Topics: []string{"https://example.com/not-subscribed"}})) @@ -73,7 +73,7 @@ func TestMatchTopic(t *testing.T) { } func TestSubscriberDoesNotBlockWhenChanIsFull(t *testing.T) { - s := NewSubscriber("", zap.NewNop()) + s := NewSubscriber("", zap.NewNop(), &TopicSelectorStore{}) s.Ready() for i := 0; i <= outBufferLength; i++ { diff --git a/subscription_test.go b/subscription_test.go index b832265b..13363b8a 100644 --- a/subscription_test.go +++ b/subscription_test.go @@ -82,13 +82,16 @@ func TestSubscriptionHandlersETag(t *testing.T) { } func TestSubscriptionsHandler(t *testing.T) { - hub := createDummy() + logger := zap.NewNop() + + hub := createDummy(WithLogger(logger)) + tss := &TopicSelectorStore{} - s1 := NewSubscriber("", zap.NewNop()) + s1 := NewSubscriber("", logger, tss) s1.SetTopics([]string{"http://example.com/foo"}, nil) require.NoError(t, hub.transport.AddSubscriber(s1)) - s2 := NewSubscriber("", zap.NewNop()) + s2 := NewSubscriber("", logger, tss) s2.SetTopics([]string{"http://example.com/bar"}, nil) require.NoError(t, hub.transport.AddSubscriber(s2)) @@ -121,13 +124,15 @@ func TestSubscriptionsHandler(t *testing.T) { } func TestSubscriptionsHandlerForTopic(t *testing.T) { - hub := createDummy() + logger := zap.NewNop() + hub := createDummy(WithLogger(logger)) + tss := &TopicSelectorStore{} - s1 := NewSubscriber("", zap.NewNop()) + s1 := NewSubscriber("", logger, tss) s1.SetTopics([]string{"http://example.com/foo"}, nil) require.NoError(t, hub.transport.AddSubscriber(s1)) - s2 := NewSubscriber("", zap.NewNop()) + s2 := NewSubscriber("", logger, tss) s2.SetTopics([]string{"http://example.com/bar"}, nil) require.NoError(t, hub.transport.AddSubscriber(s2)) @@ -166,13 +171,15 @@ func TestSubscriptionsHandlerForTopic(t *testing.T) { } func TestSubscriptionHandler(t *testing.T) { - hub := createDummy() + logger := zap.NewNop() + hub := createDummy(WithLogger(logger)) + tss := &TopicSelectorStore{} - otherS := NewSubscriber("", zap.NewNop()) + otherS := NewSubscriber("", logger, tss) otherS.SetTopics([]string{"http://example.com/other"}, nil) require.NoError(t, hub.transport.AddSubscriber(otherS)) - s := NewSubscriber("", zap.NewNop()) + s := NewSubscriber("", logger, tss) s.SetTopics([]string{"http://example.com/other", "http://example.com/{foo}"}, nil) require.NoError(t, hub.transport.AddSubscriber(s))