@@ -8,8 +8,11 @@ import (
8
8
"errors"
9
9
"fmt"
10
10
"io"
11
+ "math/rand"
11
12
"net/http"
13
+ "slices"
12
14
"strings"
15
+ "time"
13
16
)
14
17
15
18
// Client is OpenAI GPT-3 API client.
@@ -86,67 +89,195 @@ func (c *Client) newRequest(ctx context.Context, method, url string, setters ...
86
89
return req , nil
87
90
}
88
91
89
- func (c * Client ) sendRequest (req * http.Request , v Response ) error {
92
+ type RetryOptions struct {
93
+ // Retries is the number of times to retry the request. 0 means no retries.
94
+ Retries int
95
+
96
+ // RetryAboveCode is the status code above which the request should be retried.
97
+ RetryAboveCode int
98
+ RetryCodes []int
99
+ }
100
+
101
+ func NewDefaultRetryOptions () RetryOptions {
102
+ return RetryOptions {
103
+ Retries : 0 , // = one try, no retries
104
+ RetryAboveCode : 1 , // any - doesn't matter
105
+ RetryCodes : nil , // none - doesn't matter
106
+ }
107
+ }
108
+
109
+ func (r * RetryOptions ) complete (opts ... RetryOptions ) {
110
+ for _ , opt := range opts {
111
+ if opt .Retries > 0 {
112
+ r .Retries = opt .Retries
113
+ }
114
+ if opt .RetryAboveCode > 0 {
115
+ r .RetryAboveCode = opt .RetryAboveCode
116
+ }
117
+ for _ , code := range opt .RetryCodes {
118
+ if ! slices .Contains (r .RetryCodes , code ) {
119
+ r .RetryCodes = append (r .RetryCodes , code )
120
+ }
121
+ }
122
+ }
123
+ }
124
+
125
+ func (r * RetryOptions ) canRetry (statusCode int ) bool {
126
+ if r .RetryAboveCode > 0 && statusCode > r .RetryAboveCode {
127
+ return true
128
+ }
129
+ return slices .Contains (r .RetryCodes , statusCode )
130
+ }
131
+
132
+ func (c * Client ) sendRequest (req * http.Request , v Response , retryOpts ... RetryOptions ) error {
90
133
req .Header .Set ("Accept" , "application/json" )
91
134
135
+ // Default Options
136
+ options := NewDefaultRetryOptions ()
137
+ options .complete (retryOpts ... )
138
+
92
139
// Check whether Content-Type is already set, Upload Files API requires
93
140
// Content-Type == multipart/form-data
94
141
contentType := req .Header .Get ("Content-Type" )
95
142
if contentType == "" {
96
143
req .Header .Set ("Content-Type" , "application/json" )
97
144
}
98
145
99
- res , err := c .config .HTTPClient .Do (req )
100
- if err != nil {
101
- return err
146
+ const baseDelay = time .Millisecond * 200
147
+ var (
148
+ resp * http.Response
149
+ err error
150
+ failures []string
151
+ )
152
+
153
+ // Save the original request body
154
+ var bodyBytes []byte
155
+ if req .Body != nil {
156
+ bodyBytes , err = io .ReadAll (req .Body )
157
+ _ = req .Body .Close ()
158
+ if err != nil {
159
+ failures = append (failures , fmt .Sprintf ("failed to read request body: %v" , err ))
160
+ return fmt .Errorf ("failed to read request body: %v; failures: %v" , err , strings .Join (failures , "; " ))
161
+ }
102
162
}
103
163
104
- defer res .Body .Close ()
164
+ retryLoop:
165
+ for i := 0 ; i <= options .Retries ; i ++ {
105
166
106
- if isFailureStatusCode (res ) {
107
- return c .handleErrorResp (res )
108
- }
167
+ // Reset body to the original request body
168
+ if bodyBytes != nil {
169
+ req .Body = io .NopCloser (bytes .NewBuffer (bodyBytes ))
170
+ }
109
171
110
- if v != nil {
111
- v .SetHeader (res .Header )
112
- }
172
+ resp , err = c .config .HTTPClient .Do (req )
173
+ if err == nil && ! isFailureStatusCode (resp ) {
174
+ defer resp .Body .Close ()
175
+ if v != nil {
176
+ v .SetHeader (resp .Header )
177
+ }
178
+ return decodeResponse (resp .Body , v )
179
+ }
113
180
114
- return decodeResponse (res .Body , v )
115
- }
181
+ // handle connection errors
182
+ if err != nil {
183
+ failures = append (failures , fmt .Sprintf ("#%d/%d failed to send request: %v" , i + 1 , options .Retries + 1 , err ))
184
+ continue
185
+ }
116
186
117
- func (c * Client ) sendRequestRaw (req * http.Request ) (body io.ReadCloser , err error ) {
118
- resp , err := c .config .HTTPClient .Do (req )
119
- if err != nil {
120
- return
121
- }
187
+ // handle status codes
188
+ failures = append (failures , fmt .Sprintf ("#%d/%d error response received: %v" , i + 1 , options .Retries + 1 , c .handleErrorResp (resp )))
189
+
190
+ // exit on non-retriable status codes
191
+ if ! options .canRetry (resp .StatusCode ) {
192
+ failures = append (failures , fmt .Sprintf ("exiting due to non-retriable error in try #%d/%d: %v" , i + 1 , options .Retries + 1 , resp .StatusCode ))
193
+ return fmt .Errorf ("request failed on non-retriable error: %v" , strings .Join (failures , "; " ))
194
+ }
122
195
123
- if isFailureStatusCode (resp ) {
124
- err = c .handleErrorResp (resp )
125
- return
196
+ // exponential backoff
197
+ delay := baseDelay * time .Duration (1 << i )
198
+ jitter := time .Duration (rand .Int63n (int64 (baseDelay )))
199
+ select {
200
+ case <- req .Context ().Done ():
201
+ failures = append (failures , fmt .Sprintf ("exiting due to canceled context after try #%d/%d: %v" , i + 1 , options .Retries + 1 , req .Context ().Err ()))
202
+ break retryLoop
203
+ case <- time .After (delay + jitter ):
204
+ }
126
205
}
127
- return resp . Body , nil
206
+ return fmt . Errorf ( "request failed %d times: %v" , options . Retries + 1 , strings . Join ( failures , "; " ))
128
207
}
129
208
130
- func sendRequestStream [T streamable ](client * Client , req * http.Request ) (* streamReader [T ], error ) {
209
+ func sendRequestStream [T streamable ](client * Client , req * http.Request , retryOpts ... RetryOptions ) (* streamReader [T ], error ) {
131
210
req .Header .Set ("Content-Type" , "application/json" )
132
211
req .Header .Set ("Accept" , "text/event-stream" )
133
212
req .Header .Set ("Cache-Control" , "no-cache" )
134
213
req .Header .Set ("Connection" , "keep-alive" )
135
214
136
- resp , err := client .config .HTTPClient .Do (req ) //nolint:bodyclose // body is closed in stream.Close()
137
- if err != nil {
138
- return new (streamReader [T ]), err
139
- }
140
- if isFailureStatusCode (resp ) {
141
- return new (streamReader [T ]), client .handleErrorResp (resp )
142
- }
143
- return & streamReader [T ]{
144
- emptyMessagesLimit : client .config .EmptyMessagesLimit ,
145
- reader : bufio .NewReader (resp .Body ),
146
- response : resp ,
147
- errBuffer : & bytes.Buffer {},
148
- httpHeader : httpHeader (resp .Header ),
149
- }, nil
215
+ // Default Retry Options
216
+ options := NewDefaultRetryOptions ()
217
+ options .complete (retryOpts ... )
218
+
219
+ const baseDelay = time .Millisecond * 200
220
+ var (
221
+ err error
222
+ failures []string
223
+ )
224
+
225
+ // Save the original request body
226
+ var bodyBytes []byte
227
+ if req .Body != nil {
228
+ bodyBytes , err = io .ReadAll (req .Body )
229
+ _ = req .Body .Close ()
230
+ if err != nil {
231
+ failures = append (failures , fmt .Sprintf ("failed to read request body: %v" , err ))
232
+ return nil , fmt .Errorf ("failed to read request body: %v; failures: %v" , err , strings .Join (failures , "; " ))
233
+ }
234
+ }
235
+
236
+ streamRetryLoop:
237
+ for i := 0 ; i <= options .Retries ; i ++ {
238
+
239
+ // Reset body to the original request body
240
+ if bodyBytes != nil {
241
+ req .Body = io .NopCloser (bytes .NewBuffer (bodyBytes ))
242
+ }
243
+
244
+ resp , err := client .config .HTTPClient .Do (req ) //nolint:bodyclose // body is closed in stream.Close()
245
+ if err == nil && ! isFailureStatusCode (resp ) {
246
+ // we're good!
247
+ return & streamReader [T ]{
248
+ emptyMessagesLimit : client .config .EmptyMessagesLimit ,
249
+ reader : bufio .NewReader (resp .Body ),
250
+ response : resp ,
251
+ errBuffer : & bytes.Buffer {},
252
+ httpHeader : httpHeader (resp .Header ),
253
+ }, nil
254
+ }
255
+
256
+ if err != nil {
257
+ failures = append (failures , fmt .Sprintf ("#%d/%d failed to send request: %v" , i + 1 , options .Retries + 1 , err ))
258
+ continue
259
+ }
260
+
261
+ // handle status codes
262
+ failures = append (failures , fmt .Sprintf ("#%d/%d error response received: %v" , i + 1 , options .Retries + 1 , client .handleErrorResp (resp )))
263
+
264
+ // exit on non-retriable status codes
265
+ if ! options .canRetry (resp .StatusCode ) {
266
+ failures = append (failures , fmt .Sprintf ("exiting due to non-retriable error in try #%d/%d: %v" , i + 1 , options .Retries + 1 , resp .StatusCode ))
267
+ return nil , fmt .Errorf ("request failed on non-retriable error: %v" , strings .Join (failures , "; " ))
268
+ }
269
+
270
+ // exponential backoff
271
+ delay := baseDelay * time .Duration (1 << i )
272
+ jitter := time .Duration (rand .Int63n (int64 (baseDelay )))
273
+ select {
274
+ case <- req .Context ().Done ():
275
+ failures = append (failures , fmt .Sprintf ("exiting due to canceled context after try #%d/%d: %v" , i + 1 , options .Retries + 1 , req .Context ().Err ()))
276
+ break streamRetryLoop
277
+ case <- time .After (delay + jitter ):
278
+ }
279
+ }
280
+ return nil , fmt .Errorf ("request failed %d times: %v" , options .Retries + 1 , strings .Join (failures , "; " ))
150
281
}
151
282
152
283
func (c * Client ) setCommonHeaders (req * http.Request ) {
0 commit comments