diff --git a/progress/reader.go b/progress/reader.go index a26cf72..d7dc7dd 100644 --- a/progress/reader.go +++ b/progress/reader.go @@ -2,9 +2,9 @@ package progress import "io" -// Reader is an [io.Reader] that reports the number of bytes read from it via a -// callback. The callback is called at most every "updateInterval" bytes. The -// updateInterval can be set using the [Reader.WithUpdateInterval] method. +// Reader is an [io.ReadCloser] that reports the number of bytes read from it +// via a callback. The callback is called at most every "updateInterval" bytes. +// The updateInterval can be set using the [Reader.WithUpdateInterval] method. // // The following is an example of how to use [Reader] to report the progress of // reading from a file: @@ -44,4 +44,11 @@ func (r *Reader) Read(p []byte) (n int, err error) { return n, nil } -var _ io.Reader = (*Reader)(nil) +func (r *Reader) Close() error { + if closer, ok := r.inner.(io.Closer); ok { + return closer.Close() + } + return nil +} + +var _ io.ReadCloser = (*Reader)(nil) diff --git a/progress/reader_test.go b/progress/reader_test.go index bcdeedd..9d387b9 100644 --- a/progress/reader_test.go +++ b/progress/reader_test.go @@ -2,6 +2,7 @@ package progress_test import ( "bytes" + "errors" "io" "testing" @@ -24,3 +25,35 @@ func TestReader(t *testing.T) { assert.Greater(t, len(progressUpdates), 1) assert.IsIncreasing(t, progressUpdates) } + +type testReader struct { + *bytes.Reader + closed bool +} + +func (r *testReader) Close() error { + if r.closed { + return errors.New("already closed") + } + r.closed = true + return nil +} + +func TestReadCloser(t *testing.T) { + readCloser := &testReader{Reader: bytes.NewReader(bytes.Repeat([]byte{42}, 1024*1024))} + + var progressUpdates []int + progressReader := progress.NewReader(readCloser, func(readBytes int) { + progressUpdates = append(progressUpdates, readBytes) + }) + + data, err := io.ReadAll(progressReader) + assert.NoError(t, err) + assert.Equal(t, data, bytes.Repeat([]byte{42}, 1024*1024)) + + assert.Greater(t, len(progressUpdates), 1) + assert.IsIncreasing(t, progressUpdates) + + assert.NoError(t, progressReader.Close()) + assert.ErrorContains(t, progressReader.Close(), "already closed") +}