diff --git a/bundle/bundle.go b/bundle/bundle.go index 72927123a8..035a1a9c81 100644 --- a/bundle/bundle.go +++ b/bundle/bundle.go @@ -1465,14 +1465,25 @@ func preProcessBundle(loader DirectoryLoader, skipVerify bool, sizeLimitBytes in } func readFile(f *Descriptor, sizeLimitBytes int64) (bytes.Buffer, error) { + if bb, ok := f.reader.(*bytes.Buffer); ok { + _ = f.Close() // always close, even on error + + if int64(bb.Len()) >= sizeLimitBytes { + return *bb, fmt.Errorf("bundle file '%v' size (%d bytes) exceeded max size (%v bytes)", + strings.TrimPrefix(f.Path(), "/"), bb.Len(), sizeLimitBytes-1) + } + + return *bb, nil + } + var buf bytes.Buffer n, err := f.Read(&buf, sizeLimitBytes) - f.Close() // always close, even on error + _ = f.Close() // always close, even on error if err != nil && err != io.EOF { return buf, err } else if err == nil && n >= sizeLimitBytes { - return buf, fmt.Errorf("bundle file '%v' exceeded max size (%v bytes)", strings.TrimPrefix(f.Path(), "/"), sizeLimitBytes-1) + return buf, fmt.Errorf(maxSizeLimitBytesErrMsg, strings.TrimPrefix(f.Path(), "/"), n, sizeLimitBytes-1) } return buf, nil diff --git a/bundle/bundle_test.go b/bundle/bundle_test.go index 4e947727b3..6c2fa6f3e3 100644 --- a/bundle/bundle_test.go +++ b/bundle/bundle_test.go @@ -104,7 +104,7 @@ func TestReadWithSizeLimit(t *testing.T) { br := NewCustomReader(loader).WithSizeLimitBytes(4) _, err := br.Read() - if err == nil || err.Error() != "bundle file 'data.json' exceeded max size (4 bytes)" { + if err == nil || err.Error() != "bundle file 'data.json' size (5 bytes) exceeded max size (4 bytes)" { t.Fatal("expected error but got:", err) } @@ -116,7 +116,7 @@ func TestReadWithSizeLimit(t *testing.T) { br = NewCustomReader(loader).WithSizeLimitBytes(4) _, err = br.Read() - if err == nil || err.Error() != "bundle file '.signatures.json' exceeded max size (4 bytes)" { + if err == nil || err.Error() != "bundle file '.signatures.json' size (5 bytes) exceeded max size (4 bytes)" { 
t.Fatal("expected error but got:", err) } } diff --git a/bundle/file.go b/bundle/file.go index b3329b6a43..62ffeeff5f 100644 --- a/bundle/file.go +++ b/bundle/file.go @@ -17,6 +17,8 @@ import ( "github.com/open-policy-agent/opa/storage" ) +const maxSizeLimitBytesErrMsg = "bundle file %s size (%d bytes) exceeds configured size_limit_bytes (%d bytes)" + // Descriptor contains information about a file and // can be used to read the file contents. type Descriptor struct { @@ -123,14 +125,16 @@ type DirectoryLoader interface { NextFile() (*Descriptor, error) WithFilter(filter filter.LoaderFilter) DirectoryLoader WithPathFormat(PathFormat) DirectoryLoader + WithSizeLimitBytes(sizeLimitBytes int64) DirectoryLoader } type dirLoader struct { - root string - files []string - idx int - filter filter.LoaderFilter - pathFormat PathFormat + root string + files []string + idx int + filter filter.LoaderFilter + pathFormat PathFormat + maxSizeLimitBytes int64 } // Normalize root directory, ex "./src/bundle" -> "src/bundle" @@ -171,6 +175,12 @@ func (d *dirLoader) WithPathFormat(pathFormat PathFormat) DirectoryLoader { return d } +// WithSizeLimitBytes specifies the maximum size of any file in the directory to read +func (d *dirLoader) WithSizeLimitBytes(sizeLimitBytes int64) DirectoryLoader { + d.maxSizeLimitBytes = sizeLimitBytes + return d +} + func formatPath(fileName string, root string, pathFormat PathFormat) string { switch pathFormat { case SlashRooted: @@ -206,6 +216,9 @@ func (d *dirLoader) NextFile() (*Descriptor, error) { if d.filter != nil && d.filter(filepath.ToSlash(path), info, getdepth(path, false)) { return nil } + if d.maxSizeLimitBytes > 0 && info.Size() > d.maxSizeLimitBytes { + return fmt.Errorf(maxSizeLimitBytesErrMsg, strings.TrimPrefix(path, "/"), info.Size(), d.maxSizeLimitBytes) + } d.files = append(d.files, path) } else if info != nil && info.Mode().IsDir() { if d.filter != nil && d.filter(filepath.ToSlash(path), info, getdepth(path, true)) { @@ -235,14 
+248,15 @@ func (d *dirLoader) NextFile() (*Descriptor, error) { } type tarballLoader struct { - baseURL string - r io.Reader - tr *tar.Reader - files []file - idx int - filter filter.LoaderFilter - skipDir map[string]struct{} - pathFormat PathFormat + baseURL string + r io.Reader + tr *tar.Reader + files []file + idx int + filter filter.LoaderFilter + skipDir map[string]struct{} + pathFormat PathFormat + maxSizeLimitBytes int64 } type file struct { @@ -285,6 +299,12 @@ func (t *tarballLoader) WithPathFormat(pathFormat PathFormat) DirectoryLoader { return t } +// WithSizeLimitBytes specifies the maximum size of any file in the tarball to read +func (t *tarballLoader) WithSizeLimitBytes(sizeLimitBytes int64) DirectoryLoader { + t.maxSizeLimitBytes = sizeLimitBytes + return t +} + // NextFile iterates to the next file in the directory tree // and returns a file Descriptor for the file. func (t *tarballLoader) NextFile() (*Descriptor, error) { @@ -306,6 +326,7 @@ func (t *tarballLoader) NextFile() (*Descriptor, error) { for { header, err := t.tr.Next() + if err == io.EOF { break } @@ -343,6 +364,10 @@ func (t *tarballLoader) NextFile() (*Descriptor, error) { } } + if t.maxSizeLimitBytes > 0 && header.Size > t.maxSizeLimitBytes { + return nil, fmt.Errorf(maxSizeLimitBytesErrMsg, header.Name, header.Size, t.maxSizeLimitBytes) + } + f := file{name: header.Name} var buf bytes.Buffer diff --git a/bundle/file_test.go b/bundle/file_test.go index 191e806d9a..d44db53745 100644 --- a/bundle/file_test.go +++ b/bundle/file_test.go @@ -141,6 +141,29 @@ func TestDirectoryLoader(t *testing.T) { }) } +func TestTarballLoaderWithMaxSizeBytesLimit(t *testing.T) { + rootDir := t.TempDir() + tarballFile := filepath.Join(rootDir, "archive.tar.gz") + + f := testGetTarballFile(t, rootDir) + + loader := NewTarballLoaderWithBaseURL(f, tarballFile).WithSizeLimitBytes(5) + + defer f.Close() + + _, err := loader.NextFile() + if err == nil { + t.Fatal("Expected error but got nil") + } + + // Order 
of iteration over files in the tarball isn't necessarily deterministic, + // but luckily we have 2 files of 18 bytes. Just skip checking for the name here. + expected := "size (18 bytes) exceeds configured size_limit_bytes (5 bytes)" + + if !strings.Contains(err.Error(), expected) { + t.Errorf("Expected %q but got %v", expected, err) + } +} func TestTarballLoaderWithFilter(t *testing.T) { files := map[string]string{ diff --git a/bundle/filefs.go b/bundle/filefs.go index ff112082e2..e8767e1bae 100644 --- a/bundle/filefs.go +++ b/bundle/filefs.go @@ -19,12 +19,13 @@ const ( type dirLoaderFS struct { sync.Mutex - filesystem fs.FS - files []string - idx int - filter filter.LoaderFilter - root string - pathFormat PathFormat + filesystem fs.FS + files []string + idx int + filter filter.LoaderFilter + root string + pathFormat PathFormat + maxSizeLimitBytes int64 } // NewFSLoader returns a basic DirectoryLoader implementation @@ -61,6 +62,10 @@ func (d *dirLoaderFS) walkDir(path string, dirEntry fs.DirEntry, err error) erro return nil } + if d.maxSizeLimitBytes > 0 && info.Size() > d.maxSizeLimitBytes { + return fmt.Errorf("file %s size %d exceeds limit of %d", path, info.Size(), d.maxSizeLimitBytes) + } + d.files = append(d.files, path) } else if dirEntry.Type().IsDir() { if d.filter != nil && d.filter(filepath.ToSlash(path), info, getdepth(path, true)) { @@ -83,6 +88,12 @@ func (d *dirLoaderFS) WithPathFormat(pathFormat PathFormat) DirectoryLoader { return d } +// WithSizeLimitBytes specifies the maximum size of any file in the filesystem directory to read +func (d *dirLoaderFS) WithSizeLimitBytes(sizeLimitBytes int64) DirectoryLoader { + d.maxSizeLimitBytes = sizeLimitBytes + return d +} + // NextFile iterates to the next file in the directory tree // and returns a file Descriptor for the file. 
func (d *dirLoaderFS) NextFile() (*Descriptor, error) { diff --git a/download/download.go b/download/download.go index 2c52efd53d..f5b3c9c7c6 100644 --- a/download/download.go +++ b/download/download.go @@ -326,6 +326,12 @@ func (d *Downloader) download(ctx context.Context, m metrics.Metrics) (*download loader = bundle.NewTarballLoaderWithBaseURL(r, baseURL) } + // Setting the size limit on the loader allows early exit in the case + // of any file exceeding the limit, without the file getting loaded + if d.sizeLimitBytes != nil { + loader = loader.WithSizeLimitBytes(*d.sizeLimitBytes) + } + etag := resp.Header.Get("ETag") reader := bundle.NewCustomReader(loader). @@ -335,6 +341,7 @@ func (d *Downloader) download(ctx context.Context, m metrics.Metrics) (*download WithLazyLoadingMode(d.lazyLoadingMode). WithBundleName(d.bundleName). WithBundlePersistence(d.persist) + if d.sizeLimitBytes != nil { reader = reader.WithSizeLimitBytes(*d.sizeLimitBytes) }