Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: payload for a specific topic #945

Merged
merged 3 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ linters:

# deprecated
- execinquery
- exportloopref

issues:
exclude-rules:
Expand Down
8 changes: 5 additions & 3 deletions authorization.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ type claims struct {
}

type mercureClaim struct {
Publish []string `json:"publish"`
Subscribe []string `json:"subscribe"`
Payload interface{} `json:"payload"`
Publish []string `json:"publish"`
Subscribe []string `json:"subscribe"`
// Deprecated: use the Payloads field instead
Payload interface{} `json:"payload"`
Payloads map[string]interface{} `json:"payloads"`
}

type role int
Expand Down
11 changes: 6 additions & 5 deletions hub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,12 +264,13 @@ func createAnonymousDummy(options ...Option) *Hub {
}

func createDummyAuthorizedJWT(r role, topics []string) string {
return createDummyAuthorizedJWTWithPayload(r, topics, struct {
Foo string `json:"foo"`
}{Foo: "bar"})
payloads := map[string]interface{}{"*": make(map[string]string)}
payloads["*"].(map[string]string)["foo"] = "bar"

return createDummyAuthorizedJWTWithPayload(r, topics, payloads)
}

func createDummyAuthorizedJWTWithPayload(r role, topics []string, payload interface{}) string {
func createDummyAuthorizedJWTWithPayload(r role, topics []string, payloads map[string]interface{}) string {
token := jwt.New(jwt.SigningMethodHS256)

var key []byte
Expand All @@ -282,7 +283,7 @@ func createDummyAuthorizedJWTWithPayload(r role, topics []string, payload interf
token.Claims = &claims{
Mercure: mercureClaim{
Subscribe: topics,
Payload: payload,
Payloads: payloads,
},
RegisteredClaims: jwt.RegisteredClaims{},
}
Expand Down
46 changes: 38 additions & 8 deletions subscribe.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ func (h *Hub) registerSubscriber(w http.ResponseWriter, r *http.Request) (*Subsc

topics := r.URL.Query()["topic"]
if len(topics) == 0 {
http.Error(w, "Missing \"topic\" parameter.", http.StatusBadRequest)
http.Error(w, `Missing "topic" parameter.`, http.StatusBadRequest)

return nil, nil
}
Expand All @@ -203,17 +203,47 @@ func (h *Hub) registerSubscriber(w http.ResponseWriter, r *http.Request) (*Subsc
rc := h.newResponseController(w, s)
rc.flush()

if c := h.logger.Check(zap.InfoLevel, "New subscriber"); c != nil {
fields := []LogField{zap.Object("subscriber", s)}
if claims != nil && h.logger.Level() == zap.DebugLevel {
fields = append(fields, zap.Reflect("payload", claims.Mercure.Payload))
h.normalizeClaims(claims)
h.logNewSubscriber(claims, s)
h.metrics.SubscriberConnected(s)

return s, rc
}

func (h *Hub) logNewSubscriber(claims *claims, s *Subscriber) {
c := h.logger.Check(zap.InfoLevel, "New subscriber")
if c == nil {
return
}

fields := []LogField{zap.Object("subscriber", s)}
if claims != nil && h.logger.Level() == zap.DebugLevel {
if claims.Mercure.Payload != nil && h.opt.isBackwardCompatiblyEnabledWith(8) {
fields = append(
fields,
zap.Reflect("payload", claims.Mercure.Payload),
)
}

c.Write(fields...)
fields = append(
fields,
zap.Reflect("payloads", claims.Mercure.Payloads),
)
}
h.metrics.SubscriberConnected(s)

return s, rc
c.Write(fields...)
}

func (h *Hub) normalizeClaims(c *claims) {
if c == nil || c.Mercure.Payload == nil {
return
}

if h.opt.isBackwardCompatiblyEnabledWith(8) {
h.logger.Info(`Deprecated: the "mercure.payload" JWT claim deprecated since the version 8 of the protocol, use "mercure.payloads" claim with a "*" key instead.`)
} else {
c.Mercure.Payload = nil
}
}

// sendHeaders sends correct HTTP headers to create a keep-alive connection.
Expand Down
18 changes: 9 additions & 9 deletions subscribe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,12 +332,12 @@ func TestSubscribe(t *testing.T) {
testSubscribe(t, 3)
}

func testSubscribeLogs(t *testing.T, hub *Hub, payload interface{}) {
func testSubscribeLogs(t *testing.T, hub *Hub, payloads map[string]interface{}) {
t.Helper()

ctx, cancel := context.WithCancel(context.Background())
req := httptest.NewRequest(http.MethodGet, defaultHubURL+"?topic=http://example.com/reviews/{id}", nil).WithContext(ctx)
req.AddCookie(&http.Cookie{Name: "mercureAuthorization", Value: createDummyAuthorizedJWTWithPayload(roleSubscriber, []string{"http://example.com/reviews/22"}, payload)})
req.AddCookie(&http.Cookie{Name: "mercureAuthorization", Value: createDummyAuthorizedJWTWithPayload(roleSubscriber, []string{"http://example.com/reviews/22"}, payloads)})

w := &responseTester{
expectedStatusCode: http.StatusOK,
Expand All @@ -351,18 +351,18 @@ func testSubscribeLogs(t *testing.T, hub *Hub, payload interface{}) {

func TestSubscribeWithLogLevelDebug(t *testing.T) {
core, logs := observer.New(zapcore.DebugLevel)
payload := map[string]interface{}{
"bar": "baz",
"foo": "bar",
payloads := map[string]interface{}{
"*": make(map[string]string),
}

payloads["*"].(map[string]string)["bar"] = "baz"
payloads["*"].(map[string]string)["foo"] = "bar"

testSubscribeLogs(t, createDummy(
WithLogger(zap.New(core)),
), payload)
), payloads)

assert.Equal(t, 1, logs.FilterMessage("New subscriber").FilterField(
zap.Reflect("payload", payload)).Len(),
)
assert.Equal(t, 1, logs.FilterMessage("New subscriber").FilterFieldKey("payloads").Len())
}

func TestSubscribeLogLevelInfo(t *testing.T) {
Expand Down
17 changes: 15 additions & 2 deletions subscriber.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,21 @@ func (s *Subscriber) getSubscriptions(topic, context string, active bool) []subs
Topic: t,
Active: active,
}
if s.Claims != nil && s.Claims.Mercure.Payload != nil {
subscription.Payload = s.Claims.Mercure.Payload
if s.Claims != nil { //nolint:nestif
if s.Claims.Mercure.Payloads == nil {
if s.Claims.Mercure.Payload != nil {
subscription.Payload = s.Claims.Mercure.Payload
}
} else {
for k, v := range s.Claims.Mercure.Payloads {
if !s.topicSelectorStore.match(t, k) {
continue
}
subscription.Payload = v

break
}
}
}

subscriptions = append(subscriptions, subscription)
Expand Down
35 changes: 35 additions & 0 deletions subscription_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,38 @@ func TestSubscriptionHandler(t *testing.T) {
assert.Equal(t, http.StatusNotFound, res.StatusCode)
res.Body.Close()
}

func TestSubscriptionPayload(t *testing.T) {
logger := zap.NewNop()
tss := &TopicSelectorStore{}

for _, selector := range []string{"*", "http://example.com/foo", "http://example.com/{var}"} {
t.Run("selector "+selector, func(t *testing.T) {
hub := createDummy(WithLogger(logger))

s1 := NewSubscriber("", logger, tss)
s1.SetTopics([]string{"http://example.com/foo"}, nil)

s1.Claims = &claims{}
s1.Claims.Mercure.Payloads = map[string]interface{}{}
s1.Claims.Mercure.Payloads[selector] = "foo"
s1.Claims.Mercure.Payloads["http://example.com/bar"] = "bar"

require.NoError(t, hub.transport.AddSubscriber(s1))

req := httptest.NewRequest(http.MethodGet, defaultHubURL+subscriptionsPath, nil)
req.AddCookie(&http.Cookie{Name: "mercureAuthorization", Value: createDummyAuthorizedJWT(roleSubscriber, []string{"/.well-known/mercure/subscriptions"})})
w := httptest.NewRecorder()
hub.SubscriptionsHandler(w, req)
res := w.Result()
assert.Equal(t, http.StatusOK, res.StatusCode)
res.Body.Close()

var subscriptions subscriptionCollection
json.Unmarshal(w.Body.Bytes(), &subscriptions)

require.Len(t, subscriptions.Subscriptions, 1)
assert.Equal(t, "foo", subscriptions.Subscriptions[0].Payload)
})
}
}
Loading