diff --git a/routesrv/eskipbytes.go b/routesrv/eskipbytes.go index 37ebdc664a..da9f47c593 100644 --- a/routesrv/eskipbytes.go +++ b/routesrv/eskipbytes.go @@ -2,10 +2,12 @@ package routesrv import ( "bytes" + "compress/gzip" "crypto/sha256" "fmt" "net/http" "strconv" + "strings" "sync" "time" @@ -50,12 +52,15 @@ var ( // provides synchronized r/w access to them. Additionally it can // serve as an HTTP handler exposing its content. type eskipBytes struct { + mu sync.RWMutex data []byte - etag string + hash string lastModified time.Time initialized bool count int - mu sync.RWMutex + + zw *gzip.Writer + zdata []byte tracer ot.Tracer metrics metrics.Metrics @@ -77,13 +82,33 @@ func (e *eskipBytes) formatAndSet(routes []*eskip.Route) (_ int, _ string, initi if updated { e.lastModified = e.now() e.data = data - e.etag = fmt.Sprintf(`"%x"`, sha256.Sum256(e.data)) + e.zdata = e.compressLocked(data) + e.hash = fmt.Sprintf("%x", sha256.Sum256(e.data)) e.count = len(routes) } initialized = !e.initialized e.initialized = true - return len(e.data), e.etag, initialized, updated + return len(e.data), e.hash, initialized, updated +} + +// compressLocked compresses the data with gzip and returns +// the compressed data or nil if compression fails. +// e.mu must be held. +func (e *eskipBytes) compressLocked(data []byte) []byte { + var buf bytes.Buffer + if e.zw == nil { + e.zw = gzip.NewWriter(&buf) + } else { + e.zw.Reset(&buf) + } + if _, err := e.zw.Write(data); err != nil { + return nil + } + if err := e.zw.Close(); err != nil { + return nil + } + return buf.Bytes() } func (e *eskipBytes) ServeHTTP(rw http.ResponseWriter, r *http.Request) { @@ -112,17 +137,24 @@ func (e *eskipBytes) ServeHTTP(rw http.ResponseWriter, r *http.Request) { e.mu.RLock() count := e.count data := e.data - etag := e.etag + zdata := e.zdata + hash := e.hash lastModified := e.lastModified initialized := e.initialized e.mu.RUnlock() if initialized { - w.Header().Add("Etag", etag) - w.Header().Add("Content-Type", "text/plain; charset=utf-8") - w.Header().Add(routing.RoutesCountName, strconv.Itoa(count)) - - http.ServeContent(w, r, "", lastModified, bytes.NewReader(data)) + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set(routing.RoutesCountName, strconv.Itoa(count)) + + if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") && len(zdata) > 0 { + w.Header().Set("Etag", `"`+hash+`+gzip"`) + w.Header().Set("Content-Encoding", "gzip") + http.ServeContent(w, r, "", lastModified, bytes.NewReader(zdata)) + } else { + w.Header().Set("Etag", `"`+hash+`"`) + http.ServeContent(w, r, "", lastModified, bytes.NewReader(data)) + } } else { w.WriteHeader(http.StatusNotFound) } diff --git a/routesrv/polling.go b/routesrv/polling.go index 00895cf0ac..28702e5da2 100644 --- a/routesrv/polling.go +++ b/routesrv/polling.go @@ -78,8 +78,8 @@ func (p *poller) poll(wg *sync.WaitGroup) { "message", LogRoutesEmpty, ) case routesCount > 0: - routesBytes, routesEtag, initialized, updated := p.b.formatAndSet(routes) - logger := log.WithFields(log.Fields{"count": routesCount, "bytes": routesBytes, "etag": routesEtag}) + routesBytes, routesHash, initialized, updated := p.b.formatAndSet(routes) + logger := log.WithFields(log.Fields{"count": routesCount, "bytes": routesBytes, "hash": routesHash}) if initialized { logger.Info(LogRoutesInitialized) span.SetTag("routes.initialized", true) @@ -94,7 +94,7 @@ func (p *poller) poll(wg *sync.WaitGroup) { } span.SetTag("routes.count", routesCount) span.SetTag("routes.bytes", routesBytes) - span.SetTag("routes.etag", routesEtag) + span.SetTag("routes.hash", routesHash) if updated && log.IsLevelEnabled(log.DebugLevel) { routesById := mapRoutes(routes) diff --git a/routesrv/routesrv_test.go b/routesrv/routesrv_test.go index e1ed960ff2..0326932519 100644 --- a/routesrv/routesrv_test.go +++ b/routesrv/routesrv_test.go @@ -2,6 +2,7 @@ package routesrv_test import ( "bytes" + "compress/gzip" "flag" "io" "net/http" @@ -16,6 +17,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/zalando/skipper" "github.com/zalando/skipper/dataclients/kubernetes/kubernetestest" @@ -724,3 +726,93 @@ func TestRoutesWithExplicitLBAlgorithm(t *testing.T) { } wantHTTPCode(t, responseRecorder, http.StatusOK) } + +func TestESkipBytesHandlerGzip(t *testing.T) { + defer tl.Reset() + ks, handler := newKubeServer(t, loadKubeYAML(t, "testdata/lb-target-multi.yaml")) + ks.Start() + defer ks.Close() + rs := newRouteServer(t, ks) + + rs.StartUpdates() + defer rs.StopUpdates() + + testGzipResponse := func(t *testing.T, count int) { + // Get plain response + plainResponse := getRoutes(rs) + plainEtag := plainResponse.Header().Get("Etag") + plainContent := plainResponse.Body.Bytes() + + // Get gzip response + gzipResponse := getRoutesWithRequestHeadersSetting(rs, map[string]string{"Accept-Encoding": "gzip"}) + assert.Equal(t, http.StatusOK, gzipResponse.Code) + assert.Equal(t, "text/plain; charset=utf-8", gzipResponse.Header().Get("Content-Type")) + assert.Equal(t, "gzip", gzipResponse.Header().Get("Content-Encoding")) + assert.Equal(t, strconv.Itoa(count), gzipResponse.Header().Get("X-Count")) + + gzipEtag := gzipResponse.Header().Get("Etag") + assert.NotEqual(t, plainEtag, gzipEtag, "gzip Etag should differ from plain Etag") + + zr, err := gzip.NewReader(gzipResponse.Body) + require.NoError(t, err) + defer zr.Close() + + gzipContent, err := io.ReadAll(zr) + require.NoError(t, err) + + assert.Equal(t, plainContent, gzipContent, "gzip content should be equal to plain content") + + // Get gzip response using Etag + gzipEtagResponse := getRoutesWithRequestHeadersSetting(rs, map[string]string{"If-None-Match": gzipEtag, "Accept-Encoding": "gzip"}) + + assert.Equal(t, http.StatusNotModified, gzipEtagResponse.Code) + // RFC 7232 section 4.1: + assert.Empty(t, gzipEtagResponse.Header().Get("Content-Type")) + assert.Empty(t, gzipEtagResponse.Header().Get("Content-Length")) + assert.Empty(t, gzipEtagResponse.Header().Get("Content-Encoding")) + assert.Equal(t, strconv.Itoa(count), gzipEtagResponse.Header().Get("X-Count")) + assert.Empty(t, gzipEtagResponse.Body.String()) + } + + require.NoError(t, tl.WaitFor(routesrv.LogRoutesInitialized, waitTimeout)) + testGzipResponse(t, 3) + + handler.set(newKubeAPI(t, loadKubeYAML(t, "testdata/lb-target-single.yaml"))) + require.NoError(t, tl.WaitForN(routesrv.LogRoutesUpdated, 2, waitTimeout)) + + testGzipResponse(t, 2) +} + +func TestESkipBytesHandlerGzipServedForDefaultClient(t *testing.T) { + defer tl.Reset() + ks, _ := newKubeServer(t, loadKubeYAML(t, "testdata/lb-target-multi.yaml")) + ks.Start() + defer ks.Close() + + rs, err := routesrv.New(skipper.Options{ + SourcePollTimeout: pollInterval, + KubernetesURL: ks.URL, + }) + require.NoError(t, err) + + rs.StartUpdates() + defer rs.StopUpdates() + + require.NoError(t, tl.WaitFor(routesrv.LogRoutesInitialized, waitTimeout)) + + ts := httptest.NewServer(rs) + defer ts.Close() + + resp, err := ts.Client().Get(ts.URL + "/routes") + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.True(t, resp.Uncompressed, "expected uncompressed body") + + b, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + routes, err := eskip.Parse(string(b)) + require.NoError(t, err) + assert.Len(t, routes, 3) +}