Skip to content

Commit 5c0178f

Browse files
feat: add optional retryOptions to chat completion and chat completion stream requests (#3)
Co-authored-by: Donnie Adams <[email protected]>
1 parent 02b41e1 commit 5c0178f

File tree

3 files changed

+172
-39
lines changed

3 files changed

+172
-39
lines changed

chat.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,7 @@ func (c *Client) CreateChatCompletion(
329329
ctx context.Context,
330330
request ChatCompletionRequest,
331331
headers map[string]string,
332+
retryOpts ...RetryOptions,
332333
) (response ChatCompletionResponse, err error) {
333334
if request.Stream {
334335
err = ErrChatCompletionStreamNotSupported
@@ -346,6 +347,6 @@ func (c *Client) CreateChatCompletion(
346347
req.Header.Add(k, v)
347348
}
348349

349-
err = c.sendRequest(req, &response)
350+
err = c.sendRequest(req, &response, retryOpts...)
350351
return
351352
}

chat_stream.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ func (c *Client) CreateChatCompletionStream(
4343
ctx context.Context,
4444
request ChatCompletionRequest,
4545
headers map[string]string,
46+
retryOpts ...RetryOptions,
4647
) (stream *ChatCompletionStream, err error) {
4748
urlSuffix := chatCompletionsSuffix
4849
request.Stream = true
@@ -55,7 +56,7 @@ func (c *Client) CreateChatCompletionStream(
5556
req.Header.Add(k, v)
5657
}
5758

58-
resp, err := sendRequestStream[ChatCompletionStreamResponse](c, req)
59+
resp, err := sendRequestStream[ChatCompletionStreamResponse](c, req, retryOpts...)
5960
if err != nil {
6061
return
6162
}

client.go

Lines changed: 168 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,11 @@ import (
88
"errors"
99
"fmt"
1010
"io"
11+
"math/rand"
1112
"net/http"
13+
"slices"
1214
"strings"
15+
"time"
1316
)
1417

1518
// Client is OpenAI GPT-3 API client.
@@ -86,67 +89,195 @@ func (c *Client) newRequest(ctx context.Context, method, url string, setters ...
8689
return req, nil
8790
}
8891

89-
func (c *Client) sendRequest(req *http.Request, v Response) error {
92+
type RetryOptions struct {
93+
// Retries is the number of times to retry the request. 0 means no retries.
94+
Retries int
95+
96+
// RetryAboveCode is the status code above which the request should be retried.
97+
RetryAboveCode int
98+
RetryCodes []int
99+
}
100+
101+
func NewDefaultRetryOptions() RetryOptions {
102+
return RetryOptions{
103+
Retries: 0, // = one try, no retries
104+
RetryAboveCode: 1, // any - doesn't matter
105+
RetryCodes: nil, // none - doesn't matter
106+
}
107+
}
108+
109+
func (r *RetryOptions) complete(opts ...RetryOptions) {
110+
for _, opt := range opts {
111+
if opt.Retries > 0 {
112+
r.Retries = opt.Retries
113+
}
114+
if opt.RetryAboveCode > 0 {
115+
r.RetryAboveCode = opt.RetryAboveCode
116+
}
117+
for _, code := range opt.RetryCodes {
118+
if !slices.Contains(r.RetryCodes, code) {
119+
r.RetryCodes = append(r.RetryCodes, code)
120+
}
121+
}
122+
}
123+
}
124+
125+
func (r *RetryOptions) canRetry(statusCode int) bool {
126+
if r.RetryAboveCode > 0 && statusCode > r.RetryAboveCode {
127+
return true
128+
}
129+
return slices.Contains(r.RetryCodes, statusCode)
130+
}
131+
132+
func (c *Client) sendRequest(req *http.Request, v Response, retryOpts ...RetryOptions) error {
90133
req.Header.Set("Accept", "application/json")
91134

135+
// Default Options
136+
options := NewDefaultRetryOptions()
137+
options.complete(retryOpts...)
138+
92139
// Check whether Content-Type is already set, Upload Files API requires
93140
// Content-Type == multipart/form-data
94141
contentType := req.Header.Get("Content-Type")
95142
if contentType == "" {
96143
req.Header.Set("Content-Type", "application/json")
97144
}
98145

99-
res, err := c.config.HTTPClient.Do(req)
100-
if err != nil {
101-
return err
146+
const baseDelay = time.Millisecond * 200
147+
var (
148+
resp *http.Response
149+
err error
150+
failures []string
151+
)
152+
153+
// Save the original request body
154+
var bodyBytes []byte
155+
if req.Body != nil {
156+
bodyBytes, err = io.ReadAll(req.Body)
157+
_ = req.Body.Close()
158+
if err != nil {
159+
failures = append(failures, fmt.Sprintf("failed to read request body: %v", err))
160+
return fmt.Errorf("failed to read request body: %v; failures: %v", err, strings.Join(failures, "; "))
161+
}
102162
}
103163

104-
defer res.Body.Close()
164+
retryLoop:
165+
for i := 0; i <= options.Retries; i++ {
105166

106-
if isFailureStatusCode(res) {
107-
return c.handleErrorResp(res)
108-
}
167+
// Reset body to the original request body
168+
if bodyBytes != nil {
169+
req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
170+
}
109171

110-
if v != nil {
111-
v.SetHeader(res.Header)
112-
}
172+
resp, err = c.config.HTTPClient.Do(req)
173+
if err == nil && !isFailureStatusCode(resp) {
174+
defer resp.Body.Close()
175+
if v != nil {
176+
v.SetHeader(resp.Header)
177+
}
178+
return decodeResponse(resp.Body, v)
179+
}
113180

114-
return decodeResponse(res.Body, v)
115-
}
181+
// handle connection errors
182+
if err != nil {
183+
failures = append(failures, fmt.Sprintf("#%d/%d failed to send request: %v", i+1, options.Retries+1, err))
184+
continue
185+
}
116186

117-
func (c *Client) sendRequestRaw(req *http.Request) (body io.ReadCloser, err error) {
118-
resp, err := c.config.HTTPClient.Do(req)
119-
if err != nil {
120-
return
121-
}
187+
// handle status codes
188+
failures = append(failures, fmt.Sprintf("#%d/%d error response received: %v", i+1, options.Retries+1, c.handleErrorResp(resp)))
189+
190+
// exit on non-retriable status codes
191+
if !options.canRetry(resp.StatusCode) {
192+
failures = append(failures, fmt.Sprintf("exiting due to non-retriable error in try #%d/%d: %v", i+1, options.Retries+1, resp.StatusCode))
193+
return fmt.Errorf("request failed on non-retriable error: %v", strings.Join(failures, "; "))
194+
}
122195

123-
if isFailureStatusCode(resp) {
124-
err = c.handleErrorResp(resp)
125-
return
196+
// exponential backoff
197+
delay := baseDelay * time.Duration(1<<i)
198+
jitter := time.Duration(rand.Int63n(int64(baseDelay)))
199+
select {
200+
case <-req.Context().Done():
201+
failures = append(failures, fmt.Sprintf("exiting due to canceled context after try #%d/%d: %v", i+1, options.Retries+1, req.Context().Err()))
202+
break retryLoop
203+
case <-time.After(delay + jitter):
204+
}
126205
}
127-
return resp.Body, nil
206+
return fmt.Errorf("request failed %d times: %v", options.Retries+1, strings.Join(failures, "; "))
128207
}
129208

130-
func sendRequestStream[T streamable](client *Client, req *http.Request) (*streamReader[T], error) {
209+
func sendRequestStream[T streamable](client *Client, req *http.Request, retryOpts ...RetryOptions) (*streamReader[T], error) {
131210
req.Header.Set("Content-Type", "application/json")
132211
req.Header.Set("Accept", "text/event-stream")
133212
req.Header.Set("Cache-Control", "no-cache")
134213
req.Header.Set("Connection", "keep-alive")
135214

136-
resp, err := client.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close()
137-
if err != nil {
138-
return new(streamReader[T]), err
139-
}
140-
if isFailureStatusCode(resp) {
141-
return new(streamReader[T]), client.handleErrorResp(resp)
142-
}
143-
return &streamReader[T]{
144-
emptyMessagesLimit: client.config.EmptyMessagesLimit,
145-
reader: bufio.NewReader(resp.Body),
146-
response: resp,
147-
errBuffer: &bytes.Buffer{},
148-
httpHeader: httpHeader(resp.Header),
149-
}, nil
215+
// Default Retry Options
216+
options := NewDefaultRetryOptions()
217+
options.complete(retryOpts...)
218+
219+
const baseDelay = time.Millisecond * 200
220+
var (
221+
err error
222+
failures []string
223+
)
224+
225+
// Save the original request body
226+
var bodyBytes []byte
227+
if req.Body != nil {
228+
bodyBytes, err = io.ReadAll(req.Body)
229+
_ = req.Body.Close()
230+
if err != nil {
231+
failures = append(failures, fmt.Sprintf("failed to read request body: %v", err))
232+
return nil, fmt.Errorf("failed to read request body: %v; failures: %v", err, strings.Join(failures, "; "))
233+
}
234+
}
235+
236+
streamRetryLoop:
237+
for i := 0; i <= options.Retries; i++ {
238+
239+
// Reset body to the original request body
240+
if bodyBytes != nil {
241+
req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
242+
}
243+
244+
resp, err := client.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close()
245+
if err == nil && !isFailureStatusCode(resp) {
246+
// we're good!
247+
return &streamReader[T]{
248+
emptyMessagesLimit: client.config.EmptyMessagesLimit,
249+
reader: bufio.NewReader(resp.Body),
250+
response: resp,
251+
errBuffer: &bytes.Buffer{},
252+
httpHeader: httpHeader(resp.Header),
253+
}, nil
254+
}
255+
256+
if err != nil {
257+
failures = append(failures, fmt.Sprintf("#%d/%d failed to send request: %v", i+1, options.Retries+1, err))
258+
continue
259+
}
260+
261+
// handle status codes
262+
failures = append(failures, fmt.Sprintf("#%d/%d error response received: %v", i+1, options.Retries+1, client.handleErrorResp(resp)))
263+
264+
// exit on non-retriable status codes
265+
if !options.canRetry(resp.StatusCode) {
266+
failures = append(failures, fmt.Sprintf("exiting due to non-retriable error in try #%d/%d: %v", i+1, options.Retries+1, resp.StatusCode))
267+
return nil, fmt.Errorf("request failed on non-retriable error: %v", strings.Join(failures, "; "))
268+
}
269+
270+
// exponential backoff
271+
delay := baseDelay * time.Duration(1<<i)
272+
jitter := time.Duration(rand.Int63n(int64(baseDelay)))
273+
select {
274+
case <-req.Context().Done():
275+
failures = append(failures, fmt.Sprintf("exiting due to canceled context after try #%d/%d: %v", i+1, options.Retries+1, req.Context().Err()))
276+
break streamRetryLoop
277+
case <-time.After(delay + jitter):
278+
}
279+
}
280+
return nil, fmt.Errorf("request failed %d times: %v", options.Retries+1, strings.Join(failures, "; "))
150281
}
151282

152283
func (c *Client) setCommonHeaders(req *http.Request) {

0 commit comments

Comments
 (0)