diff --git a/gzip_test.go b/gzip_test.go index 0caddc9..6e0a8cb 100644 --- a/gzip_test.go +++ b/gzip_test.go @@ -241,6 +241,49 @@ func TestStatusCodes(t *testing.T) { assert.Equal(t, http.StatusOK, result.StatusCode) } +type httpFlusherFunc func() + +func (fn httpFlusherFunc) Flush() { fn() } + +func TestFlush(t *testing.T) { + b := []byte(testBody) + handler := Gzip(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(http.StatusNotFound) + rw.Write(b) + rw.(http.Flusher).Flush() + rw.Write(b) + }), MinSize(0), CompressionLevel(DefaultCompression)) + + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Set("Accept-Encoding", "gzip") + w := httptest.NewRecorder() + + var flushed bool + handler.ServeHTTP(struct { + http.ResponseWriter + http.Flusher + }{ + w, + httpFlusherFunc(func() { flushed = true }), + }, r) + + assert.True(t, flushed, "Flush did not call underlying http.Flusher") + + res := w.Result() + assert.Equal(t, http.StatusNotFound, res.StatusCode) + assert.Equal(t, "gzip", res.Header.Get("Content-Encoding")) + + var buf bytes.Buffer + gw, _ := gzip.NewWriterLevel(&buf, DefaultCompression) + + gw.Write(b) + gw.Flush() // Flush emits a symbol into the deflate output + gw.Write(b) + gw.Close() + + assert.Equal(t, buf.Bytes(), w.Body.Bytes()) +} + func TestFlushBeforeWrite(t *testing.T) { b := []byte(testBody) handler := Gzip(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {