From 20e8232ca6024a8fa72d73f156729c765c975203 Mon Sep 17 00:00:00 2001 From: Ethan Leisinger <770373+packruler@users.noreply.github.com> Date: Sat, 16 Apr 2022 17:17:27 -0600 Subject: [PATCH] Restructure system for improved maintenance going forward (#1) --- compressutil/compressutil.go | 92 ++++++++++++ compressutil/compressutil_test.go | 154 ++++++++++++++++++++ httputil/response_writer.go | 129 +++++++++++++++++ rewritebody.go | 227 +++--------------------------- rewritebody_test.go | 32 ++++- 5 files changed, 421 insertions(+), 213 deletions(-) create mode 100644 compressutil/compressutil.go create mode 100644 compressutil/compressutil_test.go create mode 100644 httputil/response_writer.go diff --git a/compressutil/compressutil.go b/compressutil/compressutil.go new file mode 100644 index 0000000..a1f8990 --- /dev/null +++ b/compressutil/compressutil.go @@ -0,0 +1,92 @@ +// Package compressutil a plugin to handle compression and decompression tasks +package compressutil + +import ( + "bytes" + "compress/flate" + "compress/gzip" + "io" + "log" +) + +// ReaderError for notating that an error occurred while reading compressed data. +type ReaderError struct { + error + + cause error +} + +// Decode data in a bytes.Reader based on supplied encoding. +func Decode(byteReader *bytes.Buffer, encoding string) (data []byte, err error) { + reader, err := getRawReader(byteReader, encoding) + if err != nil { + return nil, &ReaderError{cause: err} + } + + return io.ReadAll(reader) +} + +func getRawReader(byteReader *bytes.Buffer, encoding string) (io.Reader, error) { + switch encoding { + case "gzip": + return gzip.NewReader(byteReader) + + case "deflate": + return flate.NewReader(byteReader), nil + + default: + return byteReader, nil + } +} + +// Encode data in a []byte based on supplied encoding. +func Encode(data []byte, encoding string) ([]byte, error) { + switch encoding { + case "gzip": + return compressWithGzip(data) + + case "deflate": + return compressWithZlib(data) + + default: + return data, nil + } +} + +func compressWithGzip(bodyBytes []byte) ([]byte, error) { + var buf bytes.Buffer + gzipWriter := gzip.NewWriter(&buf) + + if _, err := gzipWriter.Write(bodyBytes); err != nil { + log.Printf("unable to recompress rewrited body: %v", err) + + return nil, err + } + + if err := gzipWriter.Close(); err != nil { + log.Printf("unable to close gzip writer: %v", err) + + return nil, err + } + + return buf.Bytes(), nil +} + +func compressWithZlib(bodyBytes []byte) ([]byte, error) { + var buf bytes.Buffer + zlibWriter, _ := flate.NewWriter(&buf, flate.DefaultCompression) + + if _, err := zlibWriter.Write(bodyBytes); err != nil { + log.Printf("unable to recompress rewrited body: %v", err) + + return nil, err + } + + if err := zlibWriter.Close(); err != nil { + log.Printf("unable to close zlib writer: %v", err) + + return nil, err + } + + return buf.Bytes(), nil +} diff --git a/compressutil/compressutil_test.go b/compressutil/compressutil_test.go new file mode 100644 index 0000000..9c8d3f3 --- /dev/null +++ b/compressutil/compressutil_test.go @@ -0,0 +1,154 @@ +package compressutil_test + +import ( + "bytes" + "testing" + + "github.com/packruler/rewrite-body/compressutil" +) + +type TestStruct struct { + desc string + input []byte + expected []byte + encoding string + shouldMatch bool +} + +func TestEncode(t *testing.T) { + var ( + deflatedBytes = []byte{ + 74, 203, 207, 87, 200, 44, 86, 40, 201, 72, 85, + 200, 75, 45, 87, 72, 74, 44, 2, 4, 0, 0, 255, 255, + } + gzippedBytes = []byte{ + 31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 74, 203, 207, 87, 200, 44, 86, 40, 201, 72, 85, + 200, 75, 45, 87, 72, 74, 44, 2, 4, 0, 0, 255, 255, 251, 28, 166, 187, 18, 0, 0, 0, + } + normalBytes = []byte("foo is the new bar") + ) + + tests := []TestStruct{ + { + desc: "should support identity", + input: normalBytes, + expected: normalBytes, + encoding: "identity", + shouldMatch: true, + }, + { + desc: "should support gzip", + input: normalBytes, + expected: gzippedBytes, + encoding: "gzip", + shouldMatch: false, + }, + { + desc: "should support deflate", + input: normalBytes, + expected: deflatedBytes, + encoding: "deflate", + shouldMatch: false, + }, + { + desc: "should NOT support brotli", + input: normalBytes, + expected: normalBytes, + encoding: "br", + shouldMatch: true, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + output, err := compressutil.Encode(test.input, test.encoding) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + isBad := !bytes.Equal(test.expected, output) + + if isBad { + t.Errorf("expected error got body: %v\n wanted: %v", output, test.expected) + } + + if test.shouldMatch { + isBad = !bytes.Equal(test.input, output) + } else { + isBad = bytes.Equal(test.input, output) + } + if isBad { + t.Errorf("match error got body: %v\n wanted: %v", output, test.input) + } + }) + } +} + +func TestDecode(t *testing.T) { + var ( + deflatedBytes = []byte{ + 74, 203, 207, 87, 200, 44, 86, 40, 201, 72, 85, + 200, 75, 45, 87, 72, 74, 44, 2, 4, 0, 0, 255, 255, + } + gzippedBytes = []byte{ + 31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 74, 203, 207, 87, 200, 44, 86, 40, 201, 72, 85, + 200, 75, 45, 87, 72, 74, 44, 2, 4, 0, 0, 255, 255, 251, 28, 166, 187, 18, 0, 0, 0, + } + normalBytes = []byte("foo is the new bar") + ) + + tests := []TestStruct{ + { + desc: "should support identity", + input: normalBytes, + expected: normalBytes, + encoding: "identity", + shouldMatch: true, + }, + { + desc: "should support gzip", + input: gzippedBytes, + expected: normalBytes, + encoding: "gzip", + shouldMatch: false, + }, + { + desc: "should support deflate", + input: deflatedBytes, + expected: normalBytes, + encoding: "deflate", + shouldMatch: false, + }, + { + desc: "should NOT support brotli", + input: normalBytes, + expected: normalBytes, + encoding: "br", + shouldMatch: true, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + output, err := compressutil.Decode(bytes.NewBuffer(test.input), test.encoding) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + isBad := !bytes.Equal(test.expected, output) + + if isBad { + t.Errorf("expected error got body: %v\n wanted: %v", output, test.expected) + } + + if test.shouldMatch { + isBad = !bytes.Equal(test.input, output) + } else { + isBad = bytes.Equal(test.input, output) + } + if isBad { + t.Errorf("match error got body: %s\n wanted: %s", output, test.input) + } + }) + } +} diff --git a/httputil/response_writer.go b/httputil/response_writer.go new file mode 100644 index 0000000..c8599e7 --- /dev/null +++ b/httputil/response_writer.go @@ -0,0 +1,129 @@ +// Package httputil a package for handling http data tasks +package httputil + +import ( + "bytes" + "log" + "net/http" + "strings" + + "github.com/packruler/rewrite-body/compressutil" +) + +// ResponseWrapper a wrapper used to simplify ResponseWriter data access and manipulation. +type ResponseWrapper struct { + buffer bytes.Buffer + lastModified bool + wroteHeader bool + + http.ResponseWriter +} + +// WriteHeader into wrapped ResponseWriter. +func (wrapper *ResponseWrapper) WriteHeader(statusCode int) { + if !wrapper.lastModified { + wrapper.ResponseWriter.Header().Del("Last-Modified") + } + + wrapper.wroteHeader = true + + // Delegates the Content-Length Header creation to the final body write. + wrapper.ResponseWriter.Header().Del("Content-Length") + + wrapper.ResponseWriter.WriteHeader(statusCode) +} + +// Write data to internal buffer and mark the status code as http.StatusOK. +func (wrapper *ResponseWrapper) Write(data []byte) (int, error) { + if !wrapper.wroteHeader { + wrapper.WriteHeader(http.StatusOK) + } + + return wrapper.buffer.Write(data) +} + +// GetBuffer get a pointer to the ResponseWriter buffer. +func (wrapper *ResponseWrapper) GetBuffer() *bytes.Buffer { + return &wrapper.buffer +} + +// GetContent load the content currently in the internal buffer +// accounting for possible encoding. +func (wrapper *ResponseWrapper) GetContent() ([]byte, error) { + encoding := wrapper.GetContentEncoding() + + return compressutil.Decode(wrapper.GetBuffer(), encoding) +} + +// SetContent write data to the internal ResponseWriter buffer +// and match initial encoding. +func (wrapper *ResponseWrapper) SetContent(data []byte) { + encoding := wrapper.GetContentEncoding() + + bodyBytes, _ := compressutil.Encode(data, encoding) + + if !wrapper.wroteHeader { + wrapper.WriteHeader(http.StatusOK) + } + + if _, err := wrapper.ResponseWriter.Write(bodyBytes); err != nil { + log.Printf("unable to write rewrited body: %v", err) + } +} + +// SupportsProcessing determine if http.Request is supported by this plugin. +func SupportsProcessing(request *http.Request) bool { + // Ignore non GET requests + if request.Method != "GET" { + return false + } + + if strings.Contains(request.Header.Get("Upgrade"), "websocket") { + log.Printf("Ignoring websocket request for %s", request.RequestURI) + + return false + } + + return true +} + +// GetContentEncoding get the Content-Encoding header value. +func (wrapper *ResponseWrapper) GetContentEncoding() string { + return wrapper.Header().Get("Content-Encoding") +} + +// GetContentType get the Content-Encoding header value. +func (wrapper *ResponseWrapper) GetContentType() string { + return wrapper.Header().Get("Content-Type") +} + +// SupportsProcessing determine if HttpWrapper is supported by this plugin based on encoding. +func (wrapper *ResponseWrapper) SupportsProcessing() bool { + contentType := wrapper.GetContentType() + + // If content type does not match return values with false + if contentType != "" && !strings.Contains(contentType, "text") { + return false + } + + encoding := wrapper.GetContentEncoding() + + // If content type is supported validate encoding as well + switch encoding { + case "gzip": + fallthrough + case "deflate": + fallthrough + case "identity": + fallthrough + case "": + return true + default: + return false + } +} + +// SetLastModified update the local lastModified variable from non-package-based users. +func (wrapper *ResponseWrapper) SetLastModified(value bool) { + wrapper.lastModified = value +} diff --git a/rewritebody.go b/rewritebody.go index 09155cf..beae6e6 100644 --- a/rewritebody.go +++ b/rewritebody.go @@ -1,19 +1,14 @@ -// Package plugin_rewritebody a plugin to rewrite response body. -package plugin_rewritebody +// Package rewrite_body a plugin to rewrite response body. +package rewrite_body import ( - "bufio" - "bytes" - "compress/gzip" - "compress/zlib" "context" "fmt" - "io" "log" - "net" "net/http" "regexp" - "strings" + + "github.com/packruler/rewrite-body/httputil" ) // Rewrite holds one rewrite body configuration. @@ -70,225 +65,41 @@ func New(_ context.Context, next http.Handler, config *Config, name string) (htt } func (bodyRewrite *rewriteBody) ServeHTTP(response http.ResponseWriter, req *http.Request) { - if req.Header.Get("Upgrade") == "websocket" { - log.Printf("Ignoring websocket upgrade| Host: \"%s\" | Path: \"%s\"", req.Host, req.URL.Path) - + if !httputil.SupportsProcessing(req) { return } - wrappedWriter := &responseWriter{ - lastModified: bodyRewrite.lastModified, + wrappedWriter := &httputil.ResponseWrapper{ ResponseWriter: response, } - bodyRewrite.next.ServeHTTP(wrappedWriter, req) + wrappedWriter.SetLastModified(bodyRewrite.lastModified) - encoding, _, isSupported := wrappedWriter.getHeaderContent() + bodyRewrite.next.ServeHTTP(wrappedWriter, req) + isSupported := wrappedWriter.SupportsProcessing() if !isSupported { - if _, err := response.Write(wrappedWriter.buffer.Bytes()); err != nil { + if _, err := response.Write(wrappedWriter.GetBuffer().Bytes()); err != nil { log.Printf("unable to write body: %v", err) } return } - bodyBytes, ok := wrappedWriter.decompressBody(encoding) - if ok { - for _, rwt := range bodyRewrite.rewrites { - bodyBytes = rwt.regex.ReplaceAll(bodyBytes, rwt.replacement) - } - - bodyBytes = prepareBodyBytes(bodyBytes, encoding) - } else { - bodyBytes = wrappedWriter.buffer.Bytes() - } - - if _, err := response.Write(bodyBytes); err != nil { - log.Printf("unable to write rewrited body: %v", err) - } -} - -func (wrappedWriter *responseWriter) getHeaderContent() (encoding string, contentType string, isSupported bool) { - encoding = wrappedWriter.Header().Get("Content-Encoding") - contentType = wrappedWriter.Header().Get("Content-Type") - - // If content type does not match return values with false - if contentType != "" && !strings.Contains(contentType, "text") { - return encoding, contentType, false - } - - // If content type is supported validate encoding as well - switch encoding { - case "gzip": - fallthrough - case "deflate": - fallthrough - case "identity": - fallthrough - case "": - return encoding, contentType, true - default: - return encoding, contentType, false - } -} - -func (wrappedWriter *responseWriter) decompressBody(encoding string) ([]byte, bool) { - switch encoding { - case "gzip": - return getBytesFromGzip(wrappedWriter.buffer) - - case "deflate": - return getBytesFromZlib(wrappedWriter.buffer) - - default: - return wrappedWriter.buffer.Bytes(), true - } -} - -func getBytesFromZlib(buffer bytes.Buffer) ([]byte, bool) { - zlibReader, err := zlib.NewReader(&buffer) - if err != nil { - log.Printf("Failed to load body reader: %v", err) - - return buffer.Bytes(), false - } - - bodyBytes, err := io.ReadAll(zlibReader) - if err != nil { - log.Printf("Failed to read body: %s", err) - - return buffer.Bytes(), false - } - - err = zlibReader.Close() - - if err != nil { - log.Printf("Failed to close reader: %v", err) - - return buffer.Bytes(), false - } - - return bodyBytes, true -} - -func getBytesFromGzip(buffer bytes.Buffer) ([]byte, bool) { - gzipReader, err := gzip.NewReader(&buffer) - if err != nil { - log.Printf("Failed to load body reader: %v", err) - - return buffer.Bytes(), false - } - - bodyBytes, err := io.ReadAll(gzipReader) - if err != nil { - log.Printf("Failed to read body: %s", err) - - return buffer.Bytes(), false - } - - err = gzipReader.Close() - + bodyBytes, err := wrappedWriter.GetContent() if err != nil { - log.Printf("Failed to close reader: %v", err) - - return buffer.Bytes(), false - } - - return bodyBytes, true -} - -func prepareBodyBytes(bodyBytes []byte, encoding string) []byte { - switch encoding { - case "gzip": - return compressWithGzip(bodyBytes) - - case "deflate": - return compressWithZlib(bodyBytes) - - default: - return bodyBytes - } -} - -func compressWithGzip(bodyBytes []byte) []byte { - var buf bytes.Buffer - gzipWriter := gzip.NewWriter(&buf) - - if _, err := gzipWriter.Write(bodyBytes); err != nil { - log.Printf("unable to recompress rewrited body: %v", err) - - return bodyBytes - } - - if err := gzipWriter.Close(); err != nil { - log.Printf("unable to close gzip writer: %v", err) - - return bodyBytes - } - - return buf.Bytes() -} - -func compressWithZlib(bodyBytes []byte) []byte { - var buf bytes.Buffer - zlibWriter := zlib.NewWriter(&buf) - - if _, err := zlibWriter.Write(bodyBytes); err != nil { - log.Printf("unable to recompress rewrited body: %v", err) - - return bodyBytes - } + log.Printf("Error loading content: %v", err) - if err := zlibWriter.Close(); err != nil { - log.Printf("unable to close zlib writer: %v", err) - - return bodyBytes - } - - return buf.Bytes() -} - -type responseWriter struct { - buffer bytes.Buffer - lastModified bool - wroteHeader bool - - http.ResponseWriter -} - -func (wrappedWriter *responseWriter) WriteHeader(statusCode int) { - if !wrappedWriter.lastModified { - wrappedWriter.ResponseWriter.Header().Del("Last-Modified") - } - - wrappedWriter.wroteHeader = true - - // Delegates the Content-Length Header creation to the final body write. - wrappedWriter.ResponseWriter.Header().Del("Content-Length") - - wrappedWriter.ResponseWriter.WriteHeader(statusCode) -} + if _, err := response.Write(wrappedWriter.GetBuffer().Bytes()); err != nil { + log.Printf("unable to write body: %v", err) + } -func (wrappedWriter *responseWriter) Write(p []byte) (int, error) { - if !wrappedWriter.wroteHeader { - wrappedWriter.WriteHeader(http.StatusOK) + return } - return wrappedWriter.buffer.Write(p) -} - -func (wrappedWriter *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - hijacker, ok := wrappedWriter.ResponseWriter.(http.Hijacker) - if !ok { - return nil, nil, fmt.Errorf("%T is not a http.Hijacker", wrappedWriter.ResponseWriter) + for _, rwt := range bodyRewrite.rewrites { + bodyBytes = rwt.regex.ReplaceAll(bodyBytes, rwt.replacement) } - return hijacker.Hijack() -} - -func (wrappedWriter *responseWriter) Flush() { - if flusher, ok := wrappedWriter.ResponseWriter.(http.Flusher); ok { - flusher.Flush() - } + wrappedWriter.SetContent(bodyBytes) } diff --git a/rewritebody_test.go b/rewritebody_test.go index 87885c9..18a7079 100644 --- a/rewritebody_test.go +++ b/rewritebody_test.go @@ -1,4 +1,4 @@ -package plugin_rewritebody +package rewrite_body import ( "bytes" @@ -8,6 +8,8 @@ import ( "net/http/httptest" "strconv" "testing" + + "github.com/packruler/rewrite-body/compressutil" ) func TestServeHTTP(t *testing.T) { @@ -108,8 +110,8 @@ func TestServeHTTP(t *testing.T) { }, contentEncoding: "gzip", lastModified: true, - resBody: string(compressWithGzip([]byte("foo is the new bar"))), - expResBody: string(compressWithGzip([]byte("bar is the new bar"))), + resBody: compressString("foo is the new bar", "gzip"), + expResBody: compressString("bar is the new bar", "gzip"), expLastModified: true, }, { @@ -122,8 +124,22 @@ func TestServeHTTP(t *testing.T) { }, contentEncoding: "deflate", lastModified: true, - resBody: string(compressWithZlib([]byte("foo is the new bar"))), - expResBody: string(compressWithZlib([]byte("bar is the new bar"))), + resBody: compressString("foo is the new bar", "deflate"), + expResBody: compressString("bar is the new bar", "deflate"), + expLastModified: true, + }, + { + desc: "should ignore unsupported encoding", + rewrites: []Rewrite{ + { + Regex: "foo", + Replacement: "bar", + }, + }, + contentEncoding: "br", + lastModified: true, + resBody: "foo is the new bar", + expResBody: "foo is the new bar", expLastModified: true, }, } @@ -170,6 +186,12 @@ func TestServeHTTP(t *testing.T) { } } +func compressString(value string, encoding string) string { + compressed, _ := compressutil.Encode([]byte(value), encoding) + + return string(compressed) +} + func TestNew(t *testing.T) { tests := []struct { desc string