diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..c277b8d --- /dev/null +++ b/Makefile @@ -0,0 +1,18 @@ +.DEFAULT_GOAL := help + +.PHONY: lint +lint: ## Lint Go files + @GOPATH="$(shell dirname $(PWD))" golangci-lint run ./... + +.PHONY: test +test: ## Run unit tests + @go test -v -race ./... + +.PHONY: coverage +coverage: ## Run unit tests with coverage + @go test -v -race -cover -coverpkg=./... -coverprofile=coverage.out -covermode=atomic ./... + @go tool cover -func=coverage.out + +.PHONY: help +help: ## Display this help screen + @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' diff --git a/drain.go b/drain.go new file mode 100644 index 0000000..218c46f --- /dev/null +++ b/drain.go @@ -0,0 +1,44 @@ +package klient + +import "io" + +type optionDrain struct { + Limit int64 +} + +func newOptionDrain(opts []OptionDrain) *optionDrain { + o := new(optionDrain) + for _, opt := range opts { + opt(o) + } + + if o.Limit == 0 { + o.Limit = ResponseErrLimit + } + + return o +} + +type OptionDrain func(*optionDrain) + +// WithDrainLimit sets the limit of the content to be read. +// If the limit is less than 0, it will read all the content. +func WithDrainLimit(limit int64) OptionDrain { + return func(o *optionDrain) { + o.Limit = limit + } +} + +// DrainBody reads the limited content of r and then closes the underlying io.ReadCloser. +func DrainBody(body io.ReadCloser, opts ...OptionDrain) { + o := newOptionDrain(opts) + + defer body.Close() + if o.Limit < 0 { + _, _ = io.Copy(io.Discard, body) + + return + } + + _, _ = io.Copy(io.Discard, io.LimitReader(body, o.Limit)) +} diff --git a/error.go b/error.go index e893171..1cac1d5 100644 --- a/error.go +++ b/error.go @@ -3,6 +3,7 @@ package klient import ( "errors" "fmt" + "io" "net/http" ) @@ -29,7 +30,7 @@ func (e *ResponseError) Error() string { // ErrResponse returns an error with the limited response body. func ErrResponse(resp *http.Response) error { - partialBody := LimitedResponse(resp) + partialBody, _ := io.ReadAll(io.LimitReader(resp.Body, ResponseErrLimit)) return &ResponseError{ StatusCode: resp.StatusCode, diff --git a/reader.go b/reader.go new file mode 100644 index 0000000..067f6e4 --- /dev/null +++ b/reader.go @@ -0,0 +1,67 @@ +package klient + +import ( + "context" + "errors" + "io" +) + +type MultiReader struct { + ctx context.Context + rs []io.ReadCloser +} + +var _ io.ReadCloser = (*MultiReader)(nil) + +// NewMultiReader returns a new read closer that reads from all the readers. +// - This helps read small amount of body and concat read data with remains io.ReadCloser. +func NewMultiReader(rs ...io.ReadCloser) *MultiReader { + return &MultiReader{rs: rs} +} + +func (r *MultiReader) SetContext(ctx context.Context) { + r.ctx = ctx +} + +func (r *MultiReader) Read(p []byte) (int, error) { + nTotal, pTotal := 0, len(p) + + index := 0 + for { + if r.ctx != nil && r.ctx.Err() != nil { + return nTotal, r.ctx.Err() + } + + if index >= len(r.rs) { + return nTotal, io.EOF + } + + rr := r.rs[index] + + n, err := rr.Read(p[nTotal:]) + nTotal += n + pTotal -= n + if pTotal == 0 { + return nTotal, err + } + + if err != nil { + if !errors.Is(err, io.EOF) { + return nTotal, err + } + + index++ + } + } +} + +func (r *MultiReader) Close() error { + var err error + for _, rr := range r.rs { + if e := rr.Close(); e != nil { + err = errors.Join(err, e) + } + } + + return err +} diff --git a/reader_test.go b/reader_test.go new file mode 100644 index 0000000..8c25ccb --- /dev/null +++ b/reader_test.go @@ -0,0 +1,100 @@ +package klient + +import ( + "bytes" + "context" + "errors" + "io" + "testing" +) + +func TestReader(t *testing.T) { + t.Run("concat reader", func(t *testing.T) { + data := []byte(` + Lorem ipsum dolor sit amet, consectetur adipiscing elit. + Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. + Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. + Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. + Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. + + Lorem ipsum dolor sit amet, consectetur adipiscing elit. + Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. + Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. + Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. + Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. + `) + readerData := bytes.NewReader(data) + + // read part of the data + partData, err := io.ReadAll(io.LimitReader(readerData, 5)) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + // merge 2 readers together + r := NewMultiReader(io.NopCloser(bytes.NewReader(partData)), io.NopCloser(readerData)) + + // read the rest of the data + allData, err := io.ReadAll(r) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if string(allData) != string(data) { + t.Errorf("expected %s, got %s", string(data), string(allData)) + } + + if err := r.Close(); err != nil { + t.Errorf("unexpected error: %v", err) + } + }) + + t.Run("context cancel", func(t *testing.T) { + data := []byte("Hello, World!") + readerData := bytes.NewReader(data) + + // read part of the data + partData, _ := io.ReadAll(io.LimitReader(readerData, 5)) + // merge 2 readers together + r := NewMultiReader(io.NopCloser(bytes.NewReader(partData)), io.NopCloser(readerData)) + + ctx, cancel := context.WithCancel(context.Background()) + r.SetContext(ctx) + cancel() + + // read the rest of the data + _, err := io.ReadAll(r) + if err == nil { + t.Errorf("expected error, got nil") + } + + if !errors.Is(err, context.Canceled) { + t.Errorf("expected context.Canceled, got %v", err) + } + }) + + t.Run("small parts", func(t *testing.T) { + data1 := []byte("Hello") + data2 := []byte(", World!") + + r := NewMultiReader(io.NopCloser(bytes.NewReader(data1)), io.NopCloser(bytes.NewReader(data2))) + + p := make([]byte, 0, 50) + n, err := r.Read(p[len(p):cap(p)]) + if !errors.Is(err, io.EOF) { + t.Errorf("unexpected error: %v", err) + } + p = p[:n] + + if lenDatas := (len(data1) + len(data2)); n != lenDatas { + t.Errorf("expected %d, got %d", lenDatas, n) + } + + if string(p) != "Hello, World!" { + t.Errorf("expected Hello, got %s", string(p)) + } + + if len(p) != 13 { + t.Errorf("expected 13, got %d", len(p)) + } + }) +} diff --git a/response.go b/response.go index edd2470..46c2d7f 100644 --- a/response.go +++ b/response.go @@ -38,6 +38,8 @@ func ResponseFuncJSON(data interface{}) func(*http.Response) error { } // LimitedResponse not close body, retry library draining it. +// - Return limited response body +// - Ready all body and assign it back to resp.Body func LimitedResponse(resp *http.Response) []byte { if resp == nil { return nil @@ -45,16 +47,7 @@ func LimitedResponse(resp *http.Response) []byte { v, _ := io.ReadAll(io.LimitReader(resp.Body, ResponseErrLimit)) - bodyRemains, _ := io.ReadAll(resp.Body) - totalBody := append(v, bodyRemains...) - - resp.Body = io.NopCloser(bytes.NewReader(totalBody)) + resp.Body = NewMultiReader(io.NopCloser(bytes.NewReader(v)), resp.Body) return v } - -// DrainBody reads the entire content of r and then closes the underlying io.ReadCloser. -func DrainBody(body io.ReadCloser) { - defer body.Close() - _, _ = io.Copy(io.Discard, io.LimitReader(body, ResponseErrLimit)) -}