diff --git a/okta/requestExecutor.go b/okta/requestExecutor.go index 230566b0a..dbcfab793 100644 --- a/okta/requestExecutor.go +++ b/okta/requestExecutor.go @@ -456,7 +456,8 @@ func (re *RequestExecutor) Do(ctx context.Context, req *http.Request, v interfac re.freshCache = false } if !inCache { - resp, err := re.doWithRetries(ctx, req) + resp, done, err := re.doWithRetries(ctx, req) + defer done() if err != nil { return nil, err } @@ -492,12 +493,13 @@ func (o *oktaBackoff) Context() context.Context { return o.ctx } -func (re *RequestExecutor) doWithRetries(ctx context.Context, req *http.Request) (*http.Response, error) { +func (re *RequestExecutor) doWithRetries(ctx context.Context, req *http.Request) (*http.Response, func(), error) { var bodyReader func() io.ReadCloser + done := func() {} if req.Body != nil { buf, err := io.ReadAll(req.Body) if err != nil { - return nil, err + return nil, done, err } bodyReader = func() io.ReadCloser { return io.NopCloser(bytes.NewReader(buf)) @@ -508,9 +510,7 @@ func (re *RequestExecutor) doWithRetries(ctx context.Context, req *http.Request) err error ) if re.config.Okta.Client.RequestTimeout > 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, time.Second*time.Duration(re.config.Okta.Client.RequestTimeout)) - defer cancel() + ctx, done = context.WithTimeout(ctx, time.Second*time.Duration(re.config.Okta.Client.RequestTimeout)) } bOff := &oktaBackoff{ ctx: ctx, @@ -549,7 +549,7 @@ func (re *RequestExecutor) doWithRetries(ctx context.Context, req *http.Request) return errors.New("too many requests") } err = backoff.Retry(operation, bOff) - return resp, err + return resp, done, err } func tooManyRequests(resp *http.Response) bool { @@ -649,7 +649,10 @@ func CheckResponseForError(resp *http.Response) error { } } } - bodyBytes, _ := io.ReadAll(resp.Body) + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return err + } copyBodyBytes := make([]byte, len(bodyBytes)) copy(copyBodyBytes, bodyBytes) _ = resp.Body.Close() @@ -668,7 +671,10 @@ func buildResponse(resp *http.Response, re *RequestExecutor, v interface{}) (*Re if err != nil { return response, err } - bodyBytes, _ := io.ReadAll(resp.Body) + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } copyBodyBytes := make([]byte, len(bodyBytes)) copy(copyBodyBytes, bodyBytes) _ = resp.Body.Close() // close it to avoid memory leaks diff --git a/tests/unit/request_executor_test.go b/tests/unit/request_executor_test.go new file mode 100644 index 000000000..1450115e8 --- /dev/null +++ b/tests/unit/request_executor_test.go @@ -0,0 +1,69 @@ +package unit + +import ( + "context" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/okta/okta-sdk-golang/v2/okta" + "github.com/okta/okta-sdk-golang/v2/tests" + "github.com/stretchr/testify/assert" +) + +// readerFun makes it easier to implement an inline reader as a closure function. +type readerFun func(p []byte) (n int, err error) + +// Read, part of io.Reader interface. +func (r readerFun) Read(p []byte) (n int, err error) { return r(p) } + +// slowTransport provides a dummy http-like transport serving fixed content, but slowly. +type slowTransport struct{} + +// RoundTrip, part of http.Transport interface. This servers 42 as a JSON response, but slowly. +// In particular, we serve the response immediately, but getting the body takes some milliseconds. +func (t slowTransport) RoundTrip(req *http.Request) (*http.Response, error) { + realBody := strings.NewReader("42") + // This body takes 1 millisecond to read. It also needs a valid context for the whole duration. + slowBody := func(p []byte) (n int, err error) { + select { + case <-req.Context().Done(): + return 0, req.Context().Err() + case <-time.After(1 * time.Millisecond): + return realBody.Read(p) + } + } + + rsp := &http.Response{ + StatusCode: 200, + Body: io.NopCloser(readerFun(slowBody)), + Header: http.Header{}, + Request: req, + } + rsp.Header.Set("Content-Type", "application/json") + return rsp, nil +} + +// TestExecuteRequest tests that the request executor can handle a slow response. +// In particular, we want to make sure that the context is properly passed through +// and not canceled too early. +func TestExecuteRequest(t *testing.T) { + cfg := []okta.ConfigSetter{ + okta.WithOrgUrl("https://fakeurl"), // This is ignored, but required for validator. + okta.WithToken("foo"), // ditto. + okta.WithHttpClientPtr(&http.Client{Transport: slowTransport{}}), // Use our more realistic transport. + okta.WithRequestTimeout(10), // The context issues are gated with actually having a timeout. + } + ctx, cl, err := tests.NewClient(context.Background(), cfg...) + assert.NoError(t, err, "Basic client errors") + req, err := http.NewRequest("GET", "https://fakeurl", http.NoBody) + assert.NoError(t, err, "Request building") + var out int + rs, err := cl.GetRequestExecutor().Do(ctx, req, &out) + assert.NoError(t, err, "Request execution") + if rs.StatusCode != 200 || out != 42 { + t.Errorf("Got val=%d status=%d, want 42 status=200", out, rs.StatusCode) + } +}