diff --git a/gzip.go b/gzip.go index 0bd54b0..907e597 100644 --- a/gzip.go +++ b/gzip.go @@ -6,6 +6,7 @@ package gziphandler import ( "bufio" "compress/gzip" + "io" "net" "net/http" "sync" @@ -30,29 +31,24 @@ var bufferPool = &sync.Pool{ }, } -var gzipPool [BestCompression - HuffmanOnly + 1]*sync.Pool +var gzipWriterPools [gzip.BestCompression - gzip.HuffmanOnly + 1]sync.Pool -func gzipPoolIndex(level int) int { - return level - HuffmanOnly +func gzipWriterPool(level int) *sync.Pool { + return &gzipWriterPools[level-gzip.HuffmanOnly] } -func addGzipPool(level int) { - gzipPool[gzipPoolIndex(level)] = &sync.Pool{ - New: func() interface{} { - w, err := gzip.NewWriterLevel(nil, level) - if err != nil { - panic(err) - } - - return w - }, +func gzipWriterGet(w io.Writer, level int) *gzip.Writer { + if gw, ok := gzipWriterPool(level).Get().(*gzip.Writer); ok { + gw.Reset(w) + return gw } + + gw, _ := gzip.NewWriterLevel(w, level) + return gw } -func init() { - for level := HuffmanOnly; level <= BestCompression; level++ { - addGzipPool(level) - } +func gzipWriterPut(gw *gzip.Writer, level int) { + gzipWriterPool(level).Put(gw) } // These constants are copied from the gzip package, so @@ -166,8 +162,7 @@ func (w *responseWriter) startGzip() (err error) { // Bytes written during ServeHTTP are redirected to // this gzip writer before being written to the // underlying response. - w.gw = w.h.pool().Get().(*gzip.Writer) - w.gw.Reset(w.ResponseWriter) + w.gw = gzipWriterGet(w.ResponseWriter, w.h.level) if buf := *w.buf; len(buf) != 0 { // Flush the buffer into the gzip response. @@ -287,7 +282,7 @@ func (w *responseWriter) Close() error { func (w *responseWriter) closeGzipped() error { err := w.gw.Close() - w.h.pool().Put(w.gw) + gzipWriterPut(w.gw, w.h.level) w.gw = nil return err @@ -330,10 +325,6 @@ type handler struct { config } -func (h *handler) pool() *sync.Pool { - return gzipPool[gzipPoolIndex(h.level)] -} - func (h *handler) shouldGzip(r *http.Request) bool { if h.config.shouldGzip != nil { switch h.config.shouldGzip(r) { diff --git a/gzip_test.go b/gzip_test.go index ec0602b..6105dfe 100644 --- a/gzip_test.go +++ b/gzip_test.go @@ -195,9 +195,6 @@ func TestMinSizePanicsForInvalid(t *testing.T) { } func TestGzipDoubleClose(t *testing.T) { - addGzipPool(DefaultCompression) - pool := gzipPool[gzipPoolIndex(DefaultCompression)] - h := Gzip(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // call close here and it'll get called again interally by // NewGzipLevelHandler's handler defer @@ -212,8 +209,8 @@ func TestGzipDoubleClose(t *testing.T) { // the second close shouldn't have added the same writer // so we pull out 2 writers from the pool and make sure they're different - w1 := pool.Get() - w2 := pool.Get() + w1 := gzipWriterGet(nil, DefaultCompression) + w2 := gzipWriterGet(nil, DefaultCompression) // assert.NotEqual looks at the value and not the address, so we use regular == assert.False(t, w1 == w2) }