From 4425ef71bbc570e4f51122a91fd5ee7428aba5d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sandor=20Sz=C3=BCcs?= Date: Fri, 5 Jan 2024 14:27:14 +0100 Subject: [PATCH] feature: io package inspect stream reader refactor: block* filters to use the new io package refactor: remove custom matcher MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Sandor Szücs --- filters/block/block.go | 68 ++++-- filters/block/block_test.go | 340 ++++++++++++++-------------- filters/block/matcher.go | 209 ------------------ io/helper_test.go | 15 ++ io/read_stream.go | 215 ++++++++++++++++++ io/read_stream_test.go | 429 ++++++++++++++++++++++++++++++++++++ proxy/proxy.go | 4 +- 7 files changed, 883 insertions(+), 397 deletions(-) delete mode 100644 filters/block/matcher.go create mode 100644 io/helper_test.go create mode 100644 io/read_stream.go create mode 100644 io/read_stream_test.go diff --git a/filters/block/block.go b/filters/block/block.go index f01b59ce63..48b0ab28b2 100644 --- a/filters/block/block.go +++ b/filters/block/block.go @@ -1,14 +1,12 @@ package block import ( + "bytes" "encoding/hex" - "errors" "github.com/zalando/skipper/filters" -) - -var ( - ErrClosed = errors.New("reader closed") + "github.com/zalando/skipper/io" + "github.com/zalando/skipper/metrics" ) type blockSpec struct { @@ -16,10 +14,17 @@ type blockSpec struct { hex bool } +type toBlockKeys struct{ Str []byte } + +func (b toBlockKeys) String() string { + return string(b.Str) +} + type block struct { - toblockList []toblockKeys + toblockList []toBlockKeys maxEditorBuffer uint64 - maxBufferHandling maxBufferHandling + maxBufferHandling io.MaxBufferHandling + metrics metrics.Metrics } // NewBlockFilter *deprecated* version of NewBlock @@ -52,7 +57,7 @@ func (bs *blockSpec) CreateFilter(args []interface{}) (filters.Filter, error) { return nil, filters.ErrInvalidFilterParameters } - sargs := make([]toblockKeys, 0, len(args)) + sargs := make([]toBlockKeys, 0, len(args)) for _, w := range args { v, ok := w.(string) if !ok { @@ -63,33 +68,50 @@ func (bs *blockSpec) CreateFilter(args []interface{}) (filters.Filter, error) { if err != nil { return nil, err } - sargs = append(sargs, toblockKeys{str: a}) + sargs = append(sargs, toBlockKeys{Str: a}) } else { - sargs = append(sargs, toblockKeys{str: []byte(v)}) + sargs = append(sargs, toBlockKeys{Str: []byte(v)}) } } - b := &block{ + return &block{ toblockList: sargs, - maxBufferHandling: maxBufferBestEffort, + maxBufferHandling: io.MaxBufferBestEffort, maxEditorBuffer: bs.MaxMatcherBufferSize, - } + metrics: metrics.Default, + }, nil +} - return *b, nil +func blockMatcher(m metrics.Metrics, matches []toBlockKeys) func(b []byte) (int, error) { + return func(b []byte) (int, error) { + for _, s := range matches { + s := s + if bytes.Contains(b, s.Str) { + m.IncCounter("blocked.requests") + return 0, io.ErrBlocked + } + } + return len(b), nil + } } -func (b block) Request(ctx filters.FilterContext) { +func (b *block) Request(ctx filters.FilterContext) { req := ctx.Request() if req.ContentLength == 0 { return } - - req.Body = newMatcher( - req.Body, - b.toblockList, - b.maxEditorBuffer, - b.maxBufferHandling, - ) + // fix filter chaining - https://github.com/zalando/skipper/issues/2605 + ctx.Request().Header.Del("Content-Length") + ctx.Request().ContentLength = -1 + + req.Body = io.InspectReader( + req.Context(), + io.BufferOptions{ + MaxBufferHandling: b.maxBufferHandling, + ReadBufferSize: b.maxEditorBuffer, + }, + blockMatcher(b.metrics, 
b.toblockList), + req.Body) } -func (block) Response(filters.FilterContext) {} +func (*block) Response(filters.FilterContext) {} diff --git a/filters/block/block_test.go b/filters/block/block_test.go index ed858bdc98..a5b2225728 100644 --- a/filters/block/block_test.go +++ b/filters/block/block_test.go @@ -2,16 +2,17 @@ package block import ( "bytes" + "fmt" "io" "net/http" "net/http/httptest" - "net/url" "strings" "testing" "github.com/zalando/skipper/eskip" "github.com/zalando/skipper/filters" - "github.com/zalando/skipper/proxy" + skpio "github.com/zalando/skipper/io" + "github.com/zalando/skipper/metrics" "github.com/zalando/skipper/proxy/proxytest" ) @@ -39,22 +40,26 @@ func TestMatcher(t *testing.T) { { name: "empty string", content: "", + block: []byte(".class"), err: nil, }, { name: "small string", content: ".class", - err: proxy.ErrBlocked, + block: []byte(".class"), + err: skpio.ErrBlocked, }, { name: "small string without match", content: "foxi", + block: []byte(".class"), err: nil, }, { name: "small string with match", content: "fox.class.foo.blah", - err: proxy.ErrBlocked, + block: []byte(".class"), + err: skpio.ErrBlocked, }, { name: "hex string 0x00 without match", @@ -65,42 +70,43 @@ func TestMatcher(t *testing.T) { name: "hex string 0x00 with match", content: "fox.c\x00.foo.blah", block: []byte("\x00"), - err: proxy.ErrBlocked, + err: skpio.ErrBlocked, }, { name: "hex string with uppercase match content string with lowercase", content: "fox.c\x0A.foo.blah", block: []byte("\x0a"), - err: proxy.ErrBlocked, + err: skpio.ErrBlocked, }, { name: "hex string 0x00 0x0a with match", content: "fox.c\x00\x0a.foo.blah", block: []byte{0, 10}, - err: proxy.ErrBlocked, + err: skpio.ErrBlocked, }, { name: "long string", content: strings.Repeat("A", 8192), + block: []byte(".class"), }} { t.Run(tt.name, func(t *testing.T) { - block := []byte(".class") - if len(tt.block) != 0 { - block = tt.block - } r := &nonBlockingReader{initialContent: []byte(tt.content)} - toblockList := []toblockKeys{{str: block}} + toblockList := []toBlockKeys{{Str: tt.block}} + + req, err := http.NewRequest("POST", "http://test.example", r) + if err != nil { + t.Fatalf("Failed to create request with body: %v", err) + } - bmb := newMatcher(r, toblockList, 2097152, maxBufferBestEffort) + bmb := skpio.InspectReader(req.Context(), skpio.BufferOptions{MaxBufferHandling: skpio.MaxBufferBestEffort}, blockMatcher(metrics.Default, toblockList), req.Body) - t.Logf("Content: %s", r.initialContent) p := make([]byte, len(r.initialContent)) n, err := bmb.Read(p) if err != tt.err { t.Fatalf("Failed to get expected err %v, got: %v", tt.err, err) } if err != nil { - if err == proxy.ErrBlocked { + if err == skpio.ErrBlocked { t.Logf("Stop! 
Request has some blocked content!") } else { t.Errorf("Failed to read: %v", err) @@ -108,60 +114,10 @@ func TestMatcher(t *testing.T) { } else if n != len(tt.content) { t.Errorf("Failed to read content length %d, got %d", len(tt.content), n) } - }) } } -func TestMatcherErrorCases(t *testing.T) { - toblockList := []toblockKeys{{str: []byte(".class")}} - t.Run("maxBufferAbort", func(t *testing.T) { - r := &nonBlockingReader{initialContent: []byte("fppppppppp .class")} - bmb := newMatcher(r, toblockList, 5, maxBufferAbort) - p := make([]byte, len(r.initialContent)) - _, err := bmb.Read(p) - if err != ErrMatcherBufferFull { - t.Errorf("Failed to get expected error %v, got: %v", ErrMatcherBufferFull, err) - } - }) - - t.Run("maxBuffer", func(t *testing.T) { - r := &nonBlockingReader{initialContent: []byte("fppppppppp .class")} - bmb := newMatcher(r, toblockList, 5, maxBufferBestEffort) - p := make([]byte, len(r.initialContent)) - _, err := bmb.Read(p) - if err != nil { - t.Errorf("Failed to read: %v", err) - } - }) - - t.Run("maxBuffer read on closed reader", func(t *testing.T) { - pipeR, pipeW := io.Pipe() - initialContent := []byte("fppppppppp") - go pipeW.Write(initialContent) - bmb := newMatcher(pipeR, toblockList, 5, maxBufferBestEffort) - p := make([]byte, len(initialContent)+10) - pipeR.Close() - _, err := bmb.Read(p) - if err == nil || err != io.ErrClosedPipe { - t.Errorf("Failed to get correct read error: %v", err) - } - }) - - t.Run("maxBuffer read on initial closed reader", func(t *testing.T) { - pipeR, _ := io.Pipe() - initialContent := []byte("fppppppppp") - bmb := newMatcher(pipeR, toblockList, 5, maxBufferBestEffort) - p := make([]byte, len(initialContent)+10) - pipeR.Close() - bmb.Close() - _, err := bmb.Read(p) - if err == nil || err.Error() != "reader closed" { - t.Errorf("Failed to get correct read error: %v", err) - } - }) -} - func TestBlockCreateFilterErrors(t *testing.T) { spec := NewBlock(1024) @@ -181,33 +137,72 @@ func TestBlockCreateFilterErrors(t *testing.T) { } func TestBlock(t *testing.T) { - backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.Copy(io.Discard, r.Body) + r.Body.Close() w.WriteHeader(200) w.Write([]byte("OK")) })) defer backend.Close() spec := NewBlock(1024) - args := []interface{}{"foo"} fr := make(filters.Registry) fr.Register(spec) - r := &eskip.Route{Filters: []*eskip.Filter{{Name: spec.Name(), Args: args}}, Backend: backend.URL} - - proxy := proxytest.New(fr, r) - defer proxy.Close() - reqURL, err := url.Parse(proxy.URL) - if err != nil { - t.Errorf("Failed to parse url %s: %v", proxy.URL, err) - } t.Run("block request", func(t *testing.T) { + r := eskip.MustParse(fmt.Sprintf(`* -> blockContent("foo") -> "%s"`, backend.URL)) + proxy := proxytest.New(fr, r...) + defer proxy.Close() + + buf := bytes.NewBufferString("hello foo world") + req, err := http.NewRequest("POST", proxy.URL, buf) + if err != nil { + t.Fatal(err) + } + + rsp, err := proxy.Client().Do(req) + if err != nil { + t.Fatal(err) + } + defer rsp.Body.Close() + if rsp.StatusCode != 400 { + t.Errorf("Not Blocked response status code %d", rsp.StatusCode) + } + }) + + t.Run("block request chain first blocks", func(t *testing.T) { + r := eskip.MustParse(fmt.Sprintf(`* -> blockContent("foo") -> blockContent("bar") -> "%s"`, backend.URL)) + proxy := proxytest.New(fr, r...) 
+ defer proxy.Close() + + buf := bytes.NewBufferString("hello foo world") + req, err := http.NewRequest("POST", proxy.URL, buf) + if err != nil { + t.Fatal(err) + } + + rsp, err := proxy.Client().Do(req) + if err != nil { + t.Fatal(err) + } + defer rsp.Body.Close() + if rsp.StatusCode != 400 { + t.Errorf("Not Blocked response status code %d", rsp.StatusCode) + } + }) + + t.Run("block request chain second blocks", func(t *testing.T) { + r := eskip.MustParse(fmt.Sprintf(`* -> blockContent("foo") -> blockContent("bar") -> "%s"`, backend.URL)) + proxy := proxytest.New(fr, r...) + defer proxy.Close() + buf := bytes.NewBufferString("hello foo world") - req, err := http.NewRequest("POST", reqURL.String(), buf) + req, err := http.NewRequest("POST", proxy.URL, buf) if err != nil { t.Fatal(err) } - rsp, err := http.DefaultClient.Do(req) + rsp, err := proxy.Client().Do(req) if err != nil { t.Fatal(err) } @@ -218,13 +213,17 @@ func TestBlock(t *testing.T) { }) t.Run("pass request", func(t *testing.T) { + r := eskip.MustParse(fmt.Sprintf(`* -> blockContent("foo") -> "%s"`, backend.URL)) + proxy := proxytest.New(fr, r...) + defer proxy.Close() + buf := bytes.NewBufferString("hello world") - req, err := http.NewRequest("POST", reqURL.String(), buf) + req, err := http.NewRequest("POST", proxy.URL, buf) if err != nil { t.Fatal(err) } - rsp, err := http.DefaultClient.Do(req) + rsp, err := proxy.Client().Do(req) if err != nil { t.Fatal(err) } @@ -234,14 +233,81 @@ func TestBlock(t *testing.T) { } }) + t.Run("pass request with filter chain and check content", func(t *testing.T) { + content := "hello world" + + be := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + res, err := io.ReadAll(r.Body) + r.Body.Close() + if err != nil { + w.WriteHeader(500) + w.Write([]byte("Failed to read body")) + return + } + if s := string(res); s != content { + t.Logf("backend received: %q", s) + w.WriteHeader(400) + w.Write([]byte("wrong body")) + return + } + w.WriteHeader(200) + w.Write([]byte("OK")) + })) + defer be.Close() + + r := eskip.MustParse(fmt.Sprintf(`* -> blockContent("foo") -> blockContent("bar") -> "%s"`, be.URL)) + proxy := proxytest.New(fr, r...) + defer proxy.Close() + + buf := bytes.NewBufferString(content) + req, err := http.NewRequest("POST", proxy.URL, buf) + if err != nil { + t.Fatal(err) + } + + rsp, err := proxy.Client().Do(req) + if err != nil { + t.Fatal(err) + } + result, _ := io.ReadAll(rsp.Body) + defer rsp.Body.Close() + if rsp.StatusCode != 200 { + t.Errorf("Blocked response status code %d: %s", rsp.StatusCode, string(result)) + } + }) + t.Run("pass request on empty body", func(t *testing.T) { - buf := bytes.NewBufferString("") - req, err := http.NewRequest("POST", reqURL.String(), buf) + r := eskip.MustParse(fmt.Sprintf(`* -> blockContent("foo") -> "%s"`, backend.URL)) + proxy := proxytest.New(fr, r...) + defer proxy.Close() + + var buf bytes.Buffer + req, err := http.NewRequest("POST", proxy.URL, &buf) + if err != nil { + t.Fatal(err) + } + + rsp, err := proxy.Client().Do(req) if err != nil { t.Fatal(err) } + defer rsp.Body.Close() + if rsp.StatusCode != 200 { + t.Errorf("Blocked response status code %d", rsp.StatusCode) + } + }) + t.Run("pass request on empty body with filter chain", func(t *testing.T) { + r := eskip.MustParse(fmt.Sprintf(`* -> blockContent("foo") -> blockContent("bar") -> "%s"`, backend.URL)) + proxy := proxytest.New(fr, r...) 
+ defer proxy.Close() - rsp, err := http.DefaultClient.Do(req) + var buf bytes.Buffer + req, err := http.NewRequest("POST", proxy.URL, &buf) + if err != nil { + t.Fatal(err) + } + + rsp, err := proxy.Client().Do(req) if err != nil { t.Fatal(err) } @@ -251,6 +317,7 @@ func TestBlock(t *testing.T) { } }) } + func TestBlockHex(t *testing.T) { backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(200) @@ -259,26 +326,22 @@ func TestBlockHex(t *testing.T) { defer backend.Close() spec := NewBlockHex(1024) - args := []interface{}{`000a`} fr := make(filters.Registry) fr.Register(spec) - r := &eskip.Route{Filters: []*eskip.Filter{{Name: spec.Name(), Args: args}}, Backend: backend.URL} - - proxy := proxytest.New(fr, r) - defer proxy.Close() - reqURL, err := url.Parse(proxy.URL) - if err != nil { - t.Errorf("Failed to parse url %s: %v", proxy.URL, err) - } t.Run("block request", func(t *testing.T) { + args := []interface{}{`000a`} + r := &eskip.Route{Filters: []*eskip.Filter{{Name: spec.Name(), Args: args}}, Backend: backend.URL} + proxy := proxytest.New(fr, r) + defer proxy.Close() + buf := bytes.NewBufferString("hello \x00\x0afoo world") - req, err := http.NewRequest("POST", reqURL.String(), buf) + req, err := http.NewRequest("POST", proxy.URL, buf) if err != nil { t.Fatal(err) } - rsp, err := http.DefaultClient.Do(req) + rsp, err := proxy.Client().Do(req) if err != nil { t.Fatal(err) } @@ -289,13 +352,18 @@ func TestBlockHex(t *testing.T) { }) t.Run("block request binary data in request", func(t *testing.T) { + args := []interface{}{`000a`} + r := &eskip.Route{Filters: []*eskip.Filter{{Name: spec.Name(), Args: args}}, Backend: backend.URL} + proxy := proxytest.New(fr, r) + defer proxy.Close() + buf := bytes.NewBuffer([]byte{65, 65, 31, 0, 10, 102, 111, 111, 31}) - req, err := http.NewRequest("POST", reqURL.String(), buf) + req, err := http.NewRequest("POST", proxy.URL, buf) if err != nil { t.Fatal(err) } - rsp, err := http.DefaultClient.Do(req) + rsp, err := proxy.Client().Do(req) if err != nil { t.Fatal(err) } @@ -306,13 +374,18 @@ func TestBlockHex(t *testing.T) { }) t.Run("pass request", func(t *testing.T) { + args := []interface{}{`000a`} + r := &eskip.Route{Filters: []*eskip.Filter{{Name: spec.Name(), Args: args}}, Backend: backend.URL} + proxy := proxytest.New(fr, r) + defer proxy.Close() + buf := bytes.NewBufferString("hello \x00a\x0a world") - req, err := http.NewRequest("POST", reqURL.String(), buf) + req, err := http.NewRequest("POST", proxy.URL, buf) if err != nil { t.Fatal(err) } - rsp, err := http.DefaultClient.Do(req) + rsp, err := proxy.Client().Do(req) if err != nil { t.Fatal(err) } @@ -323,13 +396,18 @@ func TestBlockHex(t *testing.T) { }) t.Run("pass request binary data in request", func(t *testing.T) { + args := []interface{}{`000a`} + r := &eskip.Route{Filters: []*eskip.Filter{{Name: spec.Name(), Args: args}}, Backend: backend.URL} + proxy := proxytest.New(fr, r) + defer proxy.Close() + buf := bytes.NewBuffer([]byte{65, 65, 31, 0, 11, 102, 111, 111, 31}) - req, err := http.NewRequest("POST", reqURL.String(), buf) + req, err := http.NewRequest("POST", proxy.URL, buf) if err != nil { t.Fatal(err) } - rsp, err := http.DefaultClient.Do(req) + rsp, err := proxy.Client().Do(req) if err != nil { t.Fatal(err) } @@ -339,67 +417,3 @@ func TestBlockHex(t *testing.T) { } }) } - -func BenchmarkBlock(b *testing.B) { - - fake := func(source string, len int) string { - return strings.Repeat(source[:2], len) // partially matches target - } - 
- fakematch := func(source string, len int) string { - return strings.Repeat(source, len) // matches target - } - - for _, tt := range []struct { - name string - tomatch []byte - bm []byte - }{ - { - name: "Small Stream without blocking", - tomatch: []byte(".class"), - bm: []byte(fake(".class", 1<<20)), // Test with 1Mib - }, - { - name: "Small Stream with blocking", - tomatch: []byte(".class"), - bm: []byte(fakematch(".class", 1<<20)), - }, - { - name: "Medium Stream without blocking", - tomatch: []byte(".class"), - bm: []byte(fake(".class", 1<<24)), // Test with ~10Mib - }, - { - name: "Medium Stream with blocking", - tomatch: []byte(".class"), - bm: []byte(fakematch(".class", 1<<24)), - }, - { - name: "Large Stream without blocking", - tomatch: []byte(".class"), - bm: []byte(fake(".class", 1<<27)), // Test with ~100Mib - }, - { - name: "Large Stream with blocking", - tomatch: []byte(".class"), - bm: []byte(fakematch(".class", 1<<27)), - }} { - b.Run(tt.name, func(b *testing.B) { - target := &nonBlockingReader{initialContent: tt.bm} - r := &http.Request{ - Body: target, - } - toblockList := []toblockKeys{{str: tt.tomatch}} - bmb := newMatcher(r.Body, toblockList, 2097152, maxBufferBestEffort) - p := make([]byte, len(target.initialContent)) - b.Logf("Number of loops: %b", b.N) - for n := 0; n < b.N; n++ { - _, err := bmb.Read(p) - if err != nil { - return - } - } - }) - } -} diff --git a/filters/block/matcher.go b/filters/block/matcher.go deleted file mode 100644 index b64050328b..0000000000 --- a/filters/block/matcher.go +++ /dev/null @@ -1,209 +0,0 @@ -package block - -import ( - "bytes" - "errors" - "io" - "sync" - - "github.com/zalando/skipper/metrics" - "github.com/zalando/skipper/proxy" -) - -type toblockKeys struct{ str []byte } - -const ( - readBufferSize uint64 = 8192 -) - -type maxBufferHandling int - -const ( - maxBufferBestEffort maxBufferHandling = iota - maxBufferAbort -) - -// matcher provides a reader that wraps an input reader, and blocks the request -// if a pattern was found. -// -// It reads enough data until at least a complete match of the -// pattern is met or the maxBufferSize is reached. When the pattern matches the entire -// buffered input, the replaced content is returned to the caller when maxBufferSize is -// reached. This also means that more replacements can happen than if we edited the -// entire content in one piece, but this is necessary to be able to use the matcher for -// input with unknown length. -// -// When the maxBufferHandling is set to maxBufferAbort, then the streaming is aborted -// and the rest of the payload is dropped. -// -// To limit the number of repeated scans over the buffered data, the size of the -// additional data read from the input grows exponentially with every iteration that -// didn't result with any matched data blocked. If there was any matched data -// the read size is reset to the initial value. -// -// When the input returns an error, e.g. EOF, the matcher finishes matching the buffered -// data, blocks or return it to the caller. -// -// When the matcher is closed, it doesn't read anymore from the input or return any -// buffered data. If the input implements io.Closer, closing the matcher closes the -// input, too. 
-type matcher struct { - once sync.Once - input io.ReadCloser - toblockList []toblockKeys - maxBufferSize uint64 - maxBufferHandling maxBufferHandling - readBuffer []byte - - ready *bytes.Buffer - pending *bytes.Buffer - - metrics metrics.Metrics - - err error - closed bool -} - -var ( - ErrMatcherBufferFull = errors.New("matcher buffer full") -) - -func newMatcher( - input io.ReadCloser, - toblockList []toblockKeys, - maxBufferSize uint64, - mbh maxBufferHandling, -) *matcher { - - rsize := readBufferSize - if maxBufferSize < rsize { - rsize = maxBufferSize - } - - return &matcher{ - once: sync.Once{}, - input: input, - toblockList: toblockList, - maxBufferSize: maxBufferSize, - maxBufferHandling: mbh, - readBuffer: make([]byte, rsize), - pending: bytes.NewBuffer(nil), - ready: bytes.NewBuffer(nil), - metrics: metrics.Default, - } -} - -func (m *matcher) readNTimes(times int) (bool, error) { - var consumedInput bool - for i := 0; i < times; i++ { - n, err := m.input.Read(m.readBuffer) - m.pending.Write(m.readBuffer[:n]) - if n > 0 { - consumedInput = true - } - - if err != nil { - return consumedInput, err - } - - } - - return consumedInput, nil -} - -func (m *matcher) match(b []byte) (int, error) { - var consumed int - - for _, s := range m.toblockList { - if bytes.Contains(b, s.str) { - b = nil - return 0, proxy.ErrBlocked - } - } - consumed += len(b) - return consumed, nil - -} - -func (m *matcher) fill(requested int) error { - readSize := 1 - for m.ready.Len() < requested { - consumedInput, err := m.readNTimes(readSize) - if !consumedInput { - io.CopyBuffer(m.ready, m.pending, m.readBuffer) - return err - } - - if uint64(m.pending.Len()) > m.maxBufferSize { - switch m.maxBufferHandling { - case maxBufferAbort: - return ErrMatcherBufferFull - default: - _, err := m.match(m.pending.Bytes()) - if err != nil { - return err - } - m.pending.Reset() - readSize = 1 - } - } - - readSize *= 2 - } - return nil -} - -func (m *matcher) Read(p []byte) (int, error) { - if m.closed { - return 0, ErrClosed - } - - if m.ready.Len() == 0 && m.err != nil { - return 0, m.err - } - - if m.ready.Len() < len(p) { - m.err = m.fill(len(p)) - } - - if m.err == ErrMatcherBufferFull { - return 0, ErrMatcherBufferFull - } - - if m.err == proxy.ErrBlocked { - m.metrics.IncCounter("blocked.requests") - return 0, proxy.ErrBlocked - } - - n, _ := m.ready.Read(p) - - if n == 0 && len(p) > 0 && m.err != nil { - return 0, m.err - } - - n, err := m.match(p) - - if err != nil { - m.closed = true - - if err == proxy.ErrBlocked { - m.metrics.IncCounter("blocked.requests") - } - - return 0, err - } - - return n, nil -} - -// Close closes the undelrying reader if it implements io.Closer. 
-func (m *matcher) Close() error { - var err error - m.once.Do(func() { - m.closed = true - if c, ok := m.input.(io.Closer); ok { - err = c.Close() - } - }) - return err -} diff --git a/io/helper_test.go b/io/helper_test.go new file mode 100644 index 0000000000..16f58cd679 --- /dev/null +++ b/io/helper_test.go @@ -0,0 +1,15 @@ +package io + +import "bytes" + +type mybuf struct { + buf *bytes.Buffer +} + +func (mybuf) Close() error { + return nil +} + +func (b mybuf) Read(p []byte) (int, error) { + return b.buf.Read(p) +} diff --git a/io/read_stream.go b/io/read_stream.go new file mode 100644 index 0000000000..49b176dcff --- /dev/null +++ b/io/read_stream.go @@ -0,0 +1,215 @@ +package io + +import ( + "bytes" + "context" + "errors" + "io" + "sync" +) + +var ( + ErrClosed = errors.New("reader closed") + ErrBlocked = errors.New("blocked string match found in stream") +) + +const ( + defaultReadBufferSize uint64 = 8192 +) + +type MaxBufferHandling int + +const ( + MaxBufferBestEffort MaxBufferHandling = iota + MaxBufferAbort +) + +type matcher struct { + ctx context.Context + once sync.Once + input io.ReadCloser + f func([]byte) (int, error) + maxBufferSize uint64 + maxBufferHandling MaxBufferHandling + readBuffer []byte + + ready *bytes.Buffer + pending *bytes.Buffer + + err error + closed bool +} + +var ( + ErrMatcherBufferFull = errors.New("matcher buffer full") +) + +func newMatcher( + ctx context.Context, + input io.ReadCloser, + f func([]byte) (int, error), + maxBufferSize uint64, + mbh MaxBufferHandling, +) *matcher { + + rsize := defaultReadBufferSize + if maxBufferSize < rsize { + rsize = maxBufferSize + } + + return &matcher{ + ctx: ctx, + once: sync.Once{}, + input: input, + f: f, + maxBufferSize: maxBufferSize, + maxBufferHandling: mbh, + readBuffer: make([]byte, rsize), + pending: bytes.NewBuffer(nil), + ready: bytes.NewBuffer(nil), + } +} + +func (m *matcher) readNTimes(times int) (bool, error) { + var consumedInput bool + for i := 0; i < times; i++ { + n, err := m.input.Read(m.readBuffer) + _, err2 := m.pending.Write(m.readBuffer[:n]) + + if n > 0 { + consumedInput = true + } + if err != nil { + return consumedInput, err + } + if err2 != nil { + return consumedInput, err2 + } + } + return consumedInput, nil +} + +func (m *matcher) fill(requested int) error { + readSize := 1 + for m.ready.Len() < requested { + consumedInput, err := m.readNTimes(readSize) + if !consumedInput { + io.CopyBuffer(m.ready, m.pending, m.readBuffer) + return err + } + + if uint64(m.pending.Len()) > m.maxBufferSize { + switch m.maxBufferHandling { + case MaxBufferAbort: + return ErrMatcherBufferFull + default: + select { + case <-m.ctx.Done(): + m.Close() + return m.ctx.Err() + default: + } + _, err := m.f(m.pending.Bytes()) + if err != nil { + return err + } + m.pending.Reset() + readSize = 1 + } + } + + readSize *= 2 + } + return nil +} + +func (m *matcher) Read(p []byte) (int, error) { + if m.closed { + return 0, ErrClosed + } + + if m.ready.Len() == 0 && m.err != nil { + return 0, m.err + } + + if m.ready.Len() < len(p) { + m.err = m.fill(len(p)) + } + + switch m.err { + case ErrMatcherBufferFull, ErrBlocked: + return 0, m.err + } + + n, _ := m.ready.Read(p) + if n == 0 && len(p) > 0 && m.err != nil { + return 0, m.err + } + p = p[:n] + + select { + case <-m.ctx.Done(): + m.Close() + return 0, m.ctx.Err() + default: + } + + n, err := m.f(p) + if err != nil { + m.closed = true + return 0, err + } + return n, nil +} + +// Close closes the underlying reader if it implements io.Closer. 
+func (m *matcher) Close() error { + var err error + m.once.Do(func() { + m.closed = true + if c, ok := m.input.(io.Closer); ok { + err = c.Close() + } + }) + return err +} + +/* + Wants: + - [x] filters can read the body content for example WAF scoring + - [ ] filters can change the body content for example sedRequest() + - [x] filters need to be chainable (support -> ) + - [x] filters need to be able to stop streaming to request blockContent() or WAF deny() + + TODO(sszuecs): + + 1) major optimization: use registry pattern and have only one body + wrapped for concatenating readers and run all f() in a loop, so + streaming does not happen for all but once for all + readers. Important if one write is between two readers we can not + do this, so we need to detect this case. + + 3) in case we ErrBlock, then we break the loop or cancel the + context to stop processing. The registry control layer should be + able to stop all processing. + +*/ + +type BufferOptions struct { + MaxBufferHandling MaxBufferHandling + ReadBufferSize uint64 +} + +// InspectReader wraps the given ReadCloser such that the given +// function f can inspect the streaming while streaming to the +// target. A target can be any io.ReadCloser, so for example the +// request body to the backend or the response body to the +// client. InspectReader applies given BufferOptions to the matcher. +// +// NOTE: This function is *experimental* and will likely change or disappear in the future. +func InspectReader(ctx context.Context, bo BufferOptions, f func([]byte) (int, error), rc io.ReadCloser) io.ReadCloser { + if bo.ReadBufferSize < 1 { + bo.ReadBufferSize = defaultReadBufferSize + } + return newMatcher(ctx, rc, f, bo.ReadBufferSize, bo.MaxBufferHandling) +} diff --git a/io/read_stream_test.go b/io/read_stream_test.go new file mode 100644 index 0000000000..cf5272e3d1 --- /dev/null +++ b/io/read_stream_test.go @@ -0,0 +1,429 @@ +package io + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "syscall" + "testing" + "time" +) + +type toBlockKeys struct{ Str []byte } + +func blockMatcher(matches []toBlockKeys) func(b []byte) (int, error) { + return func(b []byte) (int, error) { + var consumed int + for _, s := range matches { + if bytes.Contains(b, s.Str) { + return 0, ErrBlocked + } + } + consumed += len(b) + return consumed, nil + } +} + +func TestHttpBodyReadOnly(t *testing.T) { + sent := "hell0 foo bar" + + okBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b := make([]byte, 0, 1024) + buf := bytes.NewBuffer(b) + n, err := io.Copy(buf, r.Body) + if err != nil { + t.Fatalf("Failed to read body on backend receiver: %v", err) + } + + t.Logf("read(%d): %s", n, buf) + if got := buf.String(); got != sent { + t.Fatalf("Failed to get request body in okbackend. 
want: %q, got: %q", sent, got) + } + w.WriteHeader(200) + // w.Write([]byte("OK")) + w.Write(b[:n]) + })) + defer okBackend.Close() + + blockedBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b := make([]byte, 1024) + buf := bytes.NewBuffer(b) + _, err := io.Copy(buf, r.Body) + + // body started to stream but was cut by sender + if err != io.ErrUnexpectedEOF { + t.Logf("expected 'io.ErrUnexpectedEOF' got: %v", err) + } + + w.WriteHeader(200) + w.Write([]byte("OK")) + })) + defer blockedBackend.Close() + + t.Run("single block matcher without match", func(t *testing.T) { + var b mybuf + b.buf = bytes.NewBufferString(sent) + + body := InspectReader(context.Background(), BufferOptions{}, blockMatcher([]toBlockKeys{{Str: []byte("no match")}}), b) + defer body.Close() + rsp, err := (&http.Client{}).Post(okBackend.URL, "text/plain", body) + if err != nil { + t.Fatalf("Failed to do POST request: %v", err) + } + + if rsp.StatusCode != http.StatusOK { + t.Fatalf("Failed to get the expected status code 200, got: %d", rsp.StatusCode) + } + var buf bytes.Buffer + io.Copy(&buf, rsp.Body) + rsp.Body.Close() + if got := buf.String(); got != sent { + t.Fatalf("Failed to get %q, got %q", sent, got) + } + }) + + t.Run("double block matcher without match", func(t *testing.T) { + var b mybuf + b.buf = bytes.NewBufferString(sent) + + bod := InspectReader(context.Background(), BufferOptions{}, blockMatcher([]toBlockKeys{{Str: []byte("no-match")}}), b) + defer bod.Close() + body := InspectReader(context.Background(), BufferOptions{}, blockMatcher([]toBlockKeys{{Str: []byte("no match")}}), bod) + defer body.Close() + rsp, err := (&http.Client{}).Post(okBackend.URL, "text/plain", body) + if err != nil { + t.Fatalf("Failed to POST request: %v", err) + } + + if rsp.StatusCode != http.StatusOK { + t.Fatalf("Failed to get 200 status code, got: %v", rsp.StatusCode) + } + var buf bytes.Buffer + io.Copy(&buf, rsp.Body) + rsp.Body.Close() + if got := buf.String(); got != sent { + t.Fatalf("Failed to get %q, got %q", sent, got) + } + }) + + t.Run("single block matcher with match", func(t *testing.T) { + + var b mybuf + b.buf = bytes.NewBufferString("hell0 foo bar") + + body := InspectReader(context.Background(), BufferOptions{}, blockMatcher([]toBlockKeys{{Str: []byte("foo")}}), b) + defer body.Close() + rsp, err := (&http.Client{}).Post(blockedBackend.URL, "text/plain", body) + if !errors.Is(err, ErrBlocked) { + if rsp != nil { + t.Errorf("rsp should be nil, status code: %d", rsp.StatusCode) + } + t.Fatalf("Expected POST request to be blocked, got err: %v", err) + } + }) + + t.Run("double block matcher with first match", func(t *testing.T) { + var b mybuf + b.buf = bytes.NewBufferString("hell0 foo bar") + + body := InspectReader(context.Background(), BufferOptions{}, blockMatcher([]toBlockKeys{{Str: []byte("foo")}}), b) + body = InspectReader(context.Background(), BufferOptions{}, blockMatcher([]toBlockKeys{{Str: []byte("no match")}}), body) + defer body.Close() + rsp, err := (&http.Client{}).Post(blockedBackend.URL, "text/plain", body) + + if !errors.Is(err, ErrBlocked) { + if rsp != nil { + t.Errorf("rsp should be nil, status code: %d", rsp.StatusCode) + } + t.Fatalf("Expected POST request to be blocked, got err: %v", err) + } + }) + + t.Run("double block matcher with second match", func(t *testing.T) { + var b mybuf + b.buf = bytes.NewBufferString("hell0 foo bar") + + body := InspectReader(context.Background(), BufferOptions{}, blockMatcher([]toBlockKeys{{Str: []byte("no 
match")}}), b) + body = InspectReader(context.Background(), BufferOptions{}, blockMatcher([]toBlockKeys{{Str: []byte("bar")}}), body) + defer body.Close() + rsp, err := (&http.Client{}).Post(blockedBackend.URL, "text/plain", body) + + if !errors.Is(err, ErrBlocked) { + if rsp != nil { + t.Errorf("rsp should be nil, status code: %d", rsp.StatusCode) + } + t.Fatalf("Expected POST request to be blocked, got err: %v", err) + } + }) + +} + +type nonBlockingReader struct { + initialContent []byte +} + +func (r *nonBlockingReader) Read(p []byte) (int, error) { + n := copy(p, r.initialContent) + r.initialContent = r.initialContent[n:] + return n, nil +} + +func (r *nonBlockingReader) Close() error { + return nil +} + +func (hr *hookReader) Close() error { + return nil +} + +type slowBlockingReader struct { + initialContent []byte +} + +func (r *slowBlockingReader) Read(p []byte) (int, error) { + time.Sleep(250 * time.Millisecond) + n := copy(p, r.initialContent) + r.initialContent = r.initialContent[n:] + return n, nil +} + +func (r *slowBlockingReader) Close() error { + return nil +} + +type hookReader struct { + initialContent []byte + nbytes int + hook func() + counter int +} + +func (hr *hookReader) Read(p []byte) (int, error) { + println("Read()", len(p)) + if len(hr.initialContent) < hr.nbytes || len(p) < hr.nbytes { + return 0, nil + } + n := copy(p, hr.initialContent[:hr.nbytes]) + hr.initialContent = hr.initialContent[n:] + hr.hook() + if hr.counter > 0 { + hr.counter-- + return n, syscall.EAGAIN + } + return n, nil +} + +func TestMatcherFuncError(t *testing.T) { + t.Run("test canceled while matcher is running", func(t *testing.T) { + rc := &hookReader{ + initialContent: []byte("0123456789abcdef"), + hook: func() { + time.Sleep(11 * time.Millisecond) + }, + nbytes: 4, + } + + ctx, done := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer done() + + f := func(p []byte) (int, error) { + return len(p), nil + } + + m := newMatcher(ctx, rc, f, 1024, MaxBufferBestEffort) + + p := make([]byte, 8) + _, err := m.Read(p) + + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("Failed to read: %v", err) + } + }) + + t.Run("test read error", func(t *testing.T) { + rc := &hookReader{ + initialContent: []byte("0123456789abcdef"), + hook: func() { + time.Sleep(5 * time.Millisecond) + }, + nbytes: 4, + } + errTest := fmt.Errorf("we test an error") + + f := func(p []byte) (int, error) { + if len(p) == 8 { + return 0, errTest + } + return len(p), nil + } + + m := newMatcher(context.Background(), rc, f, 1024, MaxBufferBestEffort) + + p := make([]byte, 8) + _, err := m.Read(p) + + if !errors.Is(err, errTest) { + t.Fatalf("Failed to read: %v", err) + } + }) + + t.Run("test pending read func error", func(t *testing.T) { + rc := &hookReader{ + initialContent: []byte("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"), + hook: func() {}, + nbytes: 4, + } + errTest := fmt.Errorf("we test an error") + + f := func(p []byte) (int, error) { + switch len(p) { + case 12: + return 0, errTest + } + return len(p), nil + } + + m := newMatcher(context.Background(), rc, f, 8, MaxBufferBestEffort) + + p := make([]byte, 8) + _, err := m.Read(p) + + if !errors.Is(err, errTest) { + t.Fatalf("Failed to read: %v", err) + } + }) +} + +// TODO(sszuecs): test all error cases for matcher, the following we had for blockContent() filter +func TestMatcherErrorCases(t *testing.T) { + toblockList := []toBlockKeys{{Str: []byte(".class")}} + t.Run("maxBufferAbort", func(t *testing.T) { + r := 
&nonBlockingReader{initialContent: []byte("fppppppppp .class")} + bmb := newMatcher(context.Background(), r, blockMatcher(toblockList), 5, MaxBufferAbort) + p := make([]byte, len(r.initialContent)) + _, err := bmb.Read(p) + if err != ErrMatcherBufferFull { + t.Errorf("Failed to get expected error %v, got: %v", ErrMatcherBufferFull, err) + } + }) + + t.Run("maxBuffer", func(t *testing.T) { + r := &nonBlockingReader{initialContent: []byte("fppppppppp .class")} + bmb := newMatcher(context.Background(), r, blockMatcher(toblockList), 5, MaxBufferBestEffort) + p := make([]byte, len(r.initialContent)) + _, err := bmb.Read(p) + if err != nil { + t.Errorf("Failed to read: %v", err) + } + }) + + t.Run("cancel read", func(t *testing.T) { + r := &slowBlockingReader{initialContent: []byte("fppppppppp .class")} + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(10*time.Millisecond)) + defer cancel() + bmb := newMatcher(ctx, r, blockMatcher(toblockList), 5, MaxBufferBestEffort) + p := make([]byte, len(r.initialContent)) + _, err := bmb.Read(p) + if err == nil { + t.Errorf("Failed to cancel read: %v", err) + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("Failed to get deadline exceeded, got: %T", err) + } + }) + + t.Run("maxBuffer read on closed reader", func(t *testing.T) { + pipeR, pipeW := io.Pipe() + initialContent := []byte("fppppppppp") + go pipeW.Write(initialContent) + bmb := newMatcher(context.Background(), pipeR, blockMatcher(toblockList), 5, MaxBufferBestEffort) + p := make([]byte, len(initialContent)+10) + pipeR.Close() + _, err := bmb.Read(p) + if err == nil || err != io.ErrClosedPipe { + t.Errorf("Failed to get correct read error: %v", err) + } + }) + + t.Run("maxBuffer read on initial closed reader", func(t *testing.T) { + pipeR, _ := io.Pipe() + initialContent := []byte("fppppppppp") + bmb := newMatcher(context.Background(), pipeR, blockMatcher(toblockList), 5, MaxBufferBestEffort) + p := make([]byte, len(initialContent)+10) + pipeR.Close() + bmb.Close() + _, err := bmb.Read(p) + if err == nil || err.Error() != "reader closed" { + t.Errorf("Failed to get correct read error: %v", err) + } + }) +} + +func BenchmarkBlock(b *testing.B) { + + fake := func(source string, len int) string { + return strings.Repeat(source[:2], len) // partially matches target + } + + fakematch := func(source string, len int) string { + return strings.Repeat(source, len) // matches target + } + + for _, tt := range []struct { + name string + tomatch []byte + bm []byte + }{ + { + name: "Small Stream without blocking", + tomatch: []byte(".class"), + bm: []byte(fake(".class", 1<<20)), // Test with 1Mib + }, + { + name: "Small Stream with blocking", + tomatch: []byte(".class"), + bm: []byte(fakematch(".class", 1<<20)), + }, + { + name: "Medium Stream without blocking", + tomatch: []byte(".class"), + bm: []byte(fake(".class", 1<<24)), // Test with ~10Mib + }, + { + name: "Medium Stream with blocking", + tomatch: []byte(".class"), + bm: []byte(fakematch(".class", 1<<24)), + }, + { + name: "Large Stream without blocking", + tomatch: []byte(".class"), + bm: []byte(fake(".class", 1<<27)), // Test with ~100Mib + }, + { + name: "Large Stream with blocking", + tomatch: []byte(".class"), + bm: []byte(fakematch(".class", 1<<27)), + }} { + b.Run(tt.name, func(b *testing.B) { + target := &nonBlockingReader{initialContent: tt.bm} + r := &http.Request{ + Body: target, + } + toblockList := []toBlockKeys{{Str: tt.tomatch}} + bmb := newMatcher(context.Background(), r.Body, 
blockMatcher(toblockList), 2097152, MaxBufferBestEffort) + p := make([]byte, len(target.initialContent)) + b.Logf("Number of loops: %b", b.N) + for n := 0; n < b.N; n++ { + _, err := bmb.Read(p) + if err != nil { + return + } + } + }) + } +} diff --git a/proxy/proxy.go b/proxy/proxy.go index 7b6923cb2e..5d591323fb 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -33,6 +33,7 @@ import ( filterslog "github.com/zalando/skipper/filters/log" ratelimitfilters "github.com/zalando/skipper/filters/ratelimit" tracingfilter "github.com/zalando/skipper/filters/tracing" + skpio "github.com/zalando/skipper/io" "github.com/zalando/skipper/loadbalancer" "github.com/zalando/skipper/logging" "github.com/zalando/skipper/metrics" @@ -267,7 +268,6 @@ const ( ) var ( - ErrBlocked = errors.New("blocked string match found in body") errRouteLookupFailed = &proxyError{err: errRouteLookup} errCircuitBreakerOpen = &proxyError{ err: errors.New("circuit breaker open"), @@ -890,7 +890,7 @@ func (p *Proxy) makeBackendRequest(ctx *context, requestContext stdlibcontext.Co ctx.proxySpan.LogKV("http_roundtrip", EndEvent) if err != nil { - if errors.Is(err, ErrBlocked) { + if errors.Is(err, skpio.ErrBlocked) { p.tracing.setTag(ctx.proxySpan, BlockTag, true) p.tracing.setTag(ctx.proxySpan, HTTPStatusCodeTag, uint16(http.StatusBadRequest)) return nil, &proxyError{err: err, code: http.StatusBadRequest}
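
Usage sketch for reviewers: a minimal, self-contained example of wrapping a body with the new io.InspectReader, the same way the refactored blockContent filter does above. The contentScan helper is a simplified, hypothetical stand-in for blockMatcher (no metrics counter); everything else uses only the API added in this patch, assuming the github.com/zalando/skipper module path.

package main

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"

	skpio "github.com/zalando/skipper/io"
)

// contentScan returns an inspection callback in the shape InspectReader expects:
// it receives the bytes about to be forwarded and reports how many it consumed,
// or returns an error to stop the stream.
func contentScan(needle []byte) func([]byte) (int, error) {
	return func(p []byte) (int, error) {
		if bytes.Contains(p, needle) {
			return 0, skpio.ErrBlocked
		}
		return len(p), nil
	}
}

func main() {
	// Stand-in for a request or response body.
	payload := io.NopCloser(bytes.NewBufferString("hello foo world"))

	body := skpio.InspectReader(
		context.Background(),
		skpio.BufferOptions{
			MaxBufferHandling: skpio.MaxBufferBestEffort,
			ReadBufferSize:    8192,
		},
		contentScan([]byte("foo")),
		payload,
	)
	defer body.Close()

	_, err := io.ReadAll(body)
	switch {
	case errors.Is(err, skpio.ErrBlocked):
		fmt.Println("stream blocked") // the proxy maps ErrBlocked to HTTP 400
	case err != nil:
		fmt.Println("read error:", err)
	default:
		fmt.Println("stream passed")
	}
}

With MaxBufferBestEffort the callback is also applied to the buffered prefix whenever the internal buffer grows past ReadBufferSize, after which the buffer is reset; with MaxBufferAbort the read fails with ErrMatcherBufferFull instead.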