diff --git a/session.go b/session.go index fe0c2ac..b225846 100644 --- a/session.go +++ b/session.go @@ -163,6 +163,9 @@ func (s *SessionManager) LoadAndSave(next http.Handler) http.Handler { if !sw.written { s.commitAndWriteSessionCookie(w, sr) } + + // session could be modified after some response is written + s.commitIfModified(sr) }) } @@ -183,6 +186,15 @@ func (s *SessionManager) commitAndWriteSessionCookie(w http.ResponseWriter, r *h } } +func (s *SessionManager) commitIfModified(r *http.Request) { + ctx := r.Context() + + if s.Status(ctx) == Modified { + // since the header is already written, it's not possible to write the cookie + _, _, _ = s.Commit(ctx) + } +} + // WriteSessionCookie writes a cookie to the HTTP response with the provided // token as the cookie value and expiry as the cookie expiry time. The expiry // time will be included in the cookie only if the session is set to persist diff --git a/session_test.go b/session_test.go index f702cf7..4fd2321 100644 --- a/session_test.go +++ b/session_test.go @@ -330,7 +330,6 @@ func TestIterate(t *testing.T) { results = append(results, i) return nil }) - if err != nil { t.Fatal(err) } @@ -348,3 +347,45 @@ func TestIterate(t *testing.T) { t.Fatal("didn't get expected error") } } + +func TestFlushPop(t *testing.T) { + t.Parallel() + + sessionManager := New() + + mux := http.NewServeMux() + mux.HandleFunc("/put", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sessionManager.Put(r.Context(), "foo", "bar") + })) + mux.HandleFunc("/get", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("")) + s, _ := sessionManager.Pop(r.Context(), "foo").(string) + w.Write([]byte(s)) + })) + + ts := newTestServer(t, sessionManager.LoadAndSave(mux)) + defer ts.Close() + + header, _ := ts.execute(t, "/put") + token := extractTokenFromCookie(header.Get("Set-Cookie")) + + header, body := ts.execute(t, "/get") + if body != "bar" { + t.Errorf("want %q; got %q", "bar", body) + } + + cookie := header.Get("Set-Cookie") + if cookie == "" || extractTokenFromCookie(cookie) != token { + t.Errorf("want %q; got %q", token, cookie) + } + + header, body = ts.execute(t, "/get") + if body != "" { + t.Errorf("want %q; got %q", "", body) + } + + cookie = header.Get("Set-Cookie") + if cookie != "" { + t.Errorf("want %q; got %q", "", cookie) + } +}