Skip to content

Commit

Permalink
Merge pull request #355 from matrix-org/dmr/cancel-ctx-on-conn-shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
David Robertson authored Oct 27, 2023
2 parents 5a03b15 + db0fde8 commit 759146e
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 7 deletions.
5 changes: 5 additions & 0 deletions sync3/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ type ConnHandler interface {
PublishEventsUpTo(roomID string, nid int64)
Destroy()
Alive() bool
SetCancelCallback(cancel context.CancelFunc)
}

// Conn is an abstraction of a long-poll connection. It automatically handles the position values
Expand Down Expand Up @@ -245,3 +246,7 @@ func (c *Conn) OnIncomingRequest(ctx context.Context, req *Request, start time.T
// return the oldest value
return nextUnACKedResponse, nil
}

func (c *Conn) SetCancelCallback(cancel context.CancelFunc) {
c.handler.SetCancelCallback(cancel)
}
1 change: 1 addition & 0 deletions sync3/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ func (c *connHandlerMock) Destroy() {}
func (c *connHandlerMock) Alive() bool { return true }
func (c *connHandlerMock) OnUpdate(ctx context.Context, update caches.Update) {}
func (c *connHandlerMock) PublishEventsUpTo(roomID string, nid int64) {}
func (c *connHandlerMock) SetCancelCallback(cancel context.CancelFunc) {}

// Test that Conn can send and receive requests based on positions
func TestConn(t *testing.T) {
Expand Down
4 changes: 3 additions & 1 deletion sync3/connmap.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sync3

import (
"context"
"sync"
"time"

Expand Down Expand Up @@ -131,7 +132,7 @@ func (m *ConnMap) getConn(cid ConnID) *Conn {
}

// Atomically gets or creates a connection with this connection ID. Calls newConn if a new connection is required.
func (m *ConnMap) CreateConn(cid ConnID, newConnHandler func() ConnHandler) (*Conn, bool) {
func (m *ConnMap) CreateConn(cid ConnID, cancel context.CancelFunc, newConnHandler func() ConnHandler) (*Conn, bool) {
// atomically check if a conn exists already and nuke it if it exists
m.mu.Lock()
defer m.mu.Unlock()
Expand All @@ -149,6 +150,7 @@ func (m *ConnMap) CreateConn(cid ConnID, newConnHandler func() ConnHandler) (*Co
m.closeConn(conn)
}
h := newConnHandler()
h.SetCancelCallback(cancel)
conn = NewConn(cid, h)
m.cache.Set(cid.String(), conn)
m.connIDToConn[cid.String()] = conn
Expand Down
13 changes: 11 additions & 2 deletions sync3/handler/connstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ type ConnState struct {
userID string
deviceID string
// the only thing that can touch these data structures is the conn goroutine
muxedReq *sync3.Request
lists *sync3.InternalRequestLists
muxedReq *sync3.Request
cancelLatestReq context.CancelFunc
lists *sync3.InternalRequestLists

// Confirmed room subscriptions. Entries in this list have been checked for things like
// "is the user joined to this room?" whereas subscriptions in muxedReq are untrusted.
Expand Down Expand Up @@ -723,6 +724,10 @@ func (s *ConnState) trackProcessDuration(ctx context.Context, dur time.Duration,
// Called when the connection is torn down
func (s *ConnState) Destroy() {
s.userCache.Unsubscribe(s.userCacheID)
logger.Debug().Str("user_id", s.userID).Str("device_id", s.deviceID).Msg("cancelling any in-flight requests")
if s.cancelLatestReq != nil {
s.cancelLatestReq()
}
}

func (s *ConnState) Alive() bool {
Expand Down Expand Up @@ -764,6 +769,10 @@ func (s *ConnState) PublishEventsUpTo(roomID string, nid int64) {
s.txnIDWaiter.PublishUpToNID(roomID, nid)
}

func (s *ConnState) SetCancelCallback(cancel context.CancelFunc) {
s.cancelLatestReq = cancel
}

// clampSliceRangeToListSize helps us to send client-friendly SYNC and INVALIDATE ranges.
//
// Suppose the client asks for a window on positions [10, 19]. If the list
Expand Down
2 changes: 1 addition & 1 deletion sync3/handler/connstate_live.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func (s *connStateLive) liveUpdate(
log.Trace().Str("dur", timeLeftToWait.String()).Msg("liveUpdate: no response data yet; blocking")
select {
case <-ctx.Done(): // client has given up
log.Trace().Msg("liveUpdate: client gave up")
log.Trace().Msg("liveUpdate: client gave up, or we killed the connection")
internal.Logf(ctx, "liveUpdate", "context cancelled")
return
case <-time.After(timeLeftToWait): // we've timed out
Expand Down
9 changes: 6 additions & 3 deletions sync3/handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,9 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
}
}

req, conn, herr := h.setupConnection(req, &requestBody, req.URL.Query().Get("pos") != "")
cancelCtx, cancel := context.WithCancel(req.Context())
req = req.WithContext(cancelCtx)
req, conn, herr := h.setupConnection(req, cancel, &requestBody, req.URL.Query().Get("pos") != "")
if herr != nil {
logErrorOrWarning("failed to get or create Conn", herr)
return herr
Expand Down Expand Up @@ -326,7 +328,7 @@ func (h *SyncLiveHandler) serve(w http.ResponseWriter, req *http.Request) error
// setupConnection associates this request with an existing connection or makes a new connection.
// It also sets a v2 sync poll loop going if one didn't exist already for this user.
// When this function returns, the connection is alive and active.
func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Request, containsPos bool) (*http.Request, *sync3.Conn, *internal.HandlerError) {
func (h *SyncLiveHandler) setupConnection(req *http.Request, cancel context.CancelFunc, syncReq *sync3.Request, containsPos bool) (*http.Request, *sync3.Conn, *internal.HandlerError) {
ctx, task := internal.StartTask(req.Context(), "setupConnection")
req = req.WithContext(ctx)
defer task.End()
Expand Down Expand Up @@ -387,6 +389,7 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
// Lookup the connection
conn = h.ConnMap.Conn(connID)
if conn != nil {
conn.SetCancelCallback(cancel)
log.Trace().Str("conn", conn.ConnID.String()).Msg("reusing conn")
return req, conn, nil
}
Expand Down Expand Up @@ -434,7 +437,7 @@ func (h *SyncLiveHandler) setupConnection(req *http.Request, syncReq *sync3.Requ
// because we *either* do the existing check *or* make a new conn. It's important for CreateConn
// to check for an existing connection though, as it's possible for the client to call /sync
// twice for a new connection.
conn, created := h.ConnMap.CreateConn(connID, func() sync3.ConnHandler {
conn, created := h.ConnMap.CreateConn(connID, cancel, func() sync3.ConnHandler {
return NewConnState(token.UserID, token.DeviceID, userCache, h.GlobalCache, h.Extensions, h.Dispatcher, h.setupHistVec, h.histVec, h.maxPendingEventUpdates, h.maxTransactionIDDelay)
})
if created {
Expand Down
56 changes: 56 additions & 0 deletions tests-e2e/handler_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package syncv3_test

import (
"context"
"errors"
"github.com/matrix-org/complement/client"
"github.com/matrix-org/sliding-sync/sync3"
"net/http"
"net/url"
"testing"
)

func TestRequestCancelledWhenItsConnIsDestroyed(t *testing.T) {
alice := registerNamedUser(t, "alice")

t.Log("Alice does an initial sliding sync.")
aliceRes := alice.SlidingSync(t, sync3.Request{})

t.Log("Alice prepares a second sliding sync request.")
ctx := context.Background()
req, err := http.NewRequestWithContext(ctx, "POST", proxyBaseURL+"/_matrix/client/unstable/org.matrix.msc3575/sync", nil)
if err != nil {
t.Fatal(err)
}

client.WithQueries(url.Values{
"timeout": []string{"1000"},
"pos": []string{aliceRes.Pos},
})(req)
client.WithRawBody([]byte("{}"))
client.WithContentType("application/json")(req)
req.Header.Set("Authorization", "Bearer "+alice.AccessToken)

done := make(chan struct{})

go func() {
t.Log("Alice makes her second sync.")
t.Log(req)
_, err := alice.Client.Do(req)
if err != nil {
t.Error(err)
}

done <- struct{}{}
}()

t.Log("Alice logs out.")
alice.MustDo(t, "POST", []string{"_matrix", "client", "v3", "logout"})

t.Log("Alice waits for her second sync response.")
<-done

if !errors.Is(ctx.Err(), context.Canceled) {
t.Logf("ctx.Err(): got %v, expected %v", ctx.Err(), context.Canceled)
}
}

0 comments on commit 759146e

Please sign in to comment.