Skip to content

Commit

Permalink
feat: multi reader
Browse files Browse the repository at this point in the history
Signed-off-by: Eray Ates <[email protected]>
  • Loading branch information
rytsh committed Feb 2, 2025
1 parent 49502a9 commit 19e276b
Show file tree
Hide file tree
Showing 6 changed files with 234 additions and 11 deletions.
18 changes: 18 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
.DEFAULT_GOAL := help

.PHONY: lint
lint: ## Lint Go files
@GOPATH="$(shell dirname $(PWD))" golangci-lint run ./...

.PHONY: test
test: ## Run unit tests
@go test -v -race ./...

.PHONY: coverage
coverage: ## Run unit tests with coverage
@go test -v -race -cover -coverpkg=./... -coverprofile=coverage.out -covermode=atomic ./...
@go tool cover -func=coverage.out

.PHONY: help
help: ## Display this help screen
@grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
44 changes: 44 additions & 0 deletions drain.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package klient

import "io"

type optionDrain struct {
Limit int64
}

func newOptionDrain(opts []OptionDrain) *optionDrain {
o := new(optionDrain)
for _, opt := range opts {
opt(o)
}

if o.Limit == 0 {
o.Limit = ResponseErrLimit
}

return o
}

type OptionDrain func(*optionDrain)

// WithDrainLimit sets the limit of the content to be read.
// If the limit is less than 0, it will read all the content.
func WithDrainLimit(limit int64) OptionDrain {
return func(o *optionDrain) {
o.Limit = limit
}
}

// DrainBody reads the limited content of r and then closes the underlying io.ReadCloser.
func DrainBody(body io.ReadCloser, opts ...OptionDrain) {
o := newOptionDrain(opts)

defer body.Close()
if o.Limit < 0 {
_, _ = io.Copy(io.Discard, body)

return
}

_, _ = io.Copy(io.Discard, io.LimitReader(body, o.Limit))
}
3 changes: 2 additions & 1 deletion error.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package klient
import (
"errors"
"fmt"
"io"
"net/http"
)

Expand All @@ -29,7 +30,7 @@ func (e *ResponseError) Error() string {

// ErrResponse returns an error with the limited response body.
func ErrResponse(resp *http.Response) error {
partialBody := LimitedResponse(resp)
partialBody, _ := io.ReadAll(io.LimitReader(resp.Body, ResponseErrLimit))

return &ResponseError{
StatusCode: resp.StatusCode,
Expand Down
67 changes: 67 additions & 0 deletions reader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package klient

import (
"context"
"errors"
"io"
)

type MultiReader struct {
ctx context.Context
rs []io.ReadCloser
}

var _ io.ReadCloser = (*MultiReader)(nil)

// NewMultiReader returns a new read closer that reads from all the readers.
// - This helps read small amount of body and concat read data with remains io.ReadCloser.
func NewMultiReader(rs ...io.ReadCloser) *MultiReader {
return &MultiReader{rs: rs}
}

func (r *MultiReader) SetContext(ctx context.Context) {
r.ctx = ctx
}

func (r *MultiReader) Read(p []byte) (int, error) {
nTotal, pTotal := 0, len(p)

index := 0
for {
if r.ctx != nil && r.ctx.Err() != nil {
return nTotal, r.ctx.Err()
}

if index >= len(r.rs) {
return nTotal, io.EOF
}

rr := r.rs[index]

n, err := rr.Read(p[nTotal:])
nTotal += n
pTotal -= n
if pTotal == 0 {
return nTotal, err
}

if err != nil {
if !errors.Is(err, io.EOF) {
return nTotal, err
}

index++
}
}
}

func (r *MultiReader) Close() error {
var err error
for _, rr := range r.rs {
if e := rr.Close(); e != nil {
err = errors.Join(err, e)
}
}

return err
}
100 changes: 100 additions & 0 deletions reader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package klient

import (
"bytes"
"context"
"errors"
"io"
"testing"
)

func TestReader(t *testing.T) {
t.Run("concat reader", func(t *testing.T) {
data := []byte(`
Lorem ipsum dolor sit amet, consectetur adipiscing elit.
Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
Lorem ipsum dolor sit amet, consectetur adipiscing elit.
Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
`)
readerData := bytes.NewReader(data)

// read part of the data
partData, err := io.ReadAll(io.LimitReader(readerData, 5))
if err != nil {
t.Errorf("unexpected error: %v", err)
}
// merge 2 readers together
r := NewMultiReader(io.NopCloser(bytes.NewReader(partData)), io.NopCloser(readerData))

// read the rest of the data
allData, err := io.ReadAll(r)
if err != nil {
t.Errorf("unexpected error: %v", err)
}

if string(allData) != string(data) {
t.Errorf("expected %s, got %s", string(data), string(allData))
}

if err := r.Close(); err != nil {
t.Errorf("unexpected error: %v", err)
}
})

t.Run("context cancel", func(t *testing.T) {
data := []byte("Hello, World!")
readerData := bytes.NewReader(data)

// read part of the data
partData, _ := io.ReadAll(io.LimitReader(readerData, 5))
// merge 2 readers together
r := NewMultiReader(io.NopCloser(bytes.NewReader(partData)), io.NopCloser(readerData))

ctx, cancel := context.WithCancel(context.Background())
r.SetContext(ctx)
cancel()

// read the rest of the data
_, err := io.ReadAll(r)
if err == nil {
t.Errorf("expected error, got nil")
}

if !errors.Is(err, context.Canceled) {
t.Errorf("expected context.Canceled, got %v", err)
}
})

t.Run("small parts", func(t *testing.T) {
data1 := []byte("Hello")
data2 := []byte(", World!")

r := NewMultiReader(io.NopCloser(bytes.NewReader(data1)), io.NopCloser(bytes.NewReader(data2)))

p := make([]byte, 0, 50)
n, err := r.Read(p[len(p):cap(p)])
if !errors.Is(err, io.EOF) {
t.Errorf("unexpected error: %v", err)
}
p = p[:n]

if lenDatas := (len(data1) + len(data2)); n != lenDatas {
t.Errorf("expected %d, got %d", lenDatas, n)
}

if string(p) != "Hello, World!" {
t.Errorf("expected Hello, got %s", string(p))
}

if len(p) != 13 {
t.Errorf("expected 13, got %d", len(p))
}
})
}
13 changes: 3 additions & 10 deletions response.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,23 +38,16 @@ func ResponseFuncJSON(data interface{}) func(*http.Response) error {
}

// LimitedResponse not close body, retry library draining it.
// - Return limited response body
// - Ready all body and assign it back to resp.Body
func LimitedResponse(resp *http.Response) []byte {
if resp == nil {
return nil
}

v, _ := io.ReadAll(io.LimitReader(resp.Body, ResponseErrLimit))

bodyRemains, _ := io.ReadAll(resp.Body)
totalBody := append(v, bodyRemains...)

resp.Body = io.NopCloser(bytes.NewReader(totalBody))
resp.Body = NewMultiReader(io.NopCloser(bytes.NewReader(v)), resp.Body)

return v
}

// DrainBody reads the entire content of r and then closes the underlying io.ReadCloser.
func DrainBody(body io.ReadCloser) {
defer body.Close()
_, _ = io.Copy(io.Discard, io.LimitReader(body, ResponseErrLimit))
}

0 comments on commit 19e276b

Please sign in to comment.