@@ -401,11 +401,24 @@ type ErrorHandler func(resp *http.Response, err error, numTries int) (*http.Resp
401
401
// PrepareRetry is called before retry operation. It can be used for example to re-sign the request
402
402
type PrepareRetry func (req * http.Request ) error
403
403
404
+ type HTTPClient interface {
405
+ // Do performs an HTTP request and returns an HTTP response.
406
+ Do (* http.Request ) (* http.Response , error )
407
+ // Done is called when the client is no longer needed.
408
+ Done ()
409
+ }
410
+
411
+ type HTTPClientFactory interface {
412
+ // New returns an HTTP client to use for a request, including retries.
413
+ New () HTTPClient
414
+ }
415
+
404
416
// Client is used to make HTTP requests. It adds additional functionality
405
417
// like automatic retries to tolerate minor outages.
406
418
type Client struct {
407
- HTTPClient * http.Client // Internal HTTP client.
408
- Logger interface {} // Customer logger instance. Can be either Logger or LeveledLogger
419
+ HTTPClient * http.Client // Internal HTTP client. This field is used if set, otherwise HTTPClientFactory is used.
420
+ HTTPClientFactory HTTPClientFactory
421
+ Logger interface {} // Customer logger instance. Can be either Logger or LeveledLogger
409
422
410
423
RetryWaitMin time.Duration // Minimum time to wait
411
424
RetryWaitMax time.Duration // Maximum time to wait
@@ -433,19 +446,18 @@ type Client struct {
433
446
PrepareRetry PrepareRetry
434
447
435
448
loggerInit sync.Once
436
- clientInit sync.Once
437
449
}
438
450
439
451
// NewClient creates a new Client with default settings.
440
452
func NewClient () * Client {
441
453
return & Client {
442
- HTTPClient : cleanhttp . DefaultPooledClient () ,
443
- Logger : defaultLogger ,
444
- RetryWaitMin : defaultRetryWaitMin ,
445
- RetryWaitMax : defaultRetryWaitMax ,
446
- RetryMax : defaultRetryMax ,
447
- CheckRetry : DefaultRetryPolicy ,
448
- Backoff : DefaultBackoff ,
454
+ HTTPClientFactory : & CleanPooledClientFactory {} ,
455
+ Logger : defaultLogger ,
456
+ RetryWaitMin : defaultRetryWaitMin ,
457
+ RetryWaitMax : defaultRetryWaitMax ,
458
+ RetryMax : defaultRetryMax ,
459
+ CheckRetry : DefaultRetryPolicy ,
460
+ Backoff : DefaultBackoff ,
449
461
}
450
462
}
451
463
@@ -647,12 +659,6 @@ func PassthroughErrorHandler(resp *http.Response, err error, _ int) (*http.Respo
647
659
648
660
// Do wraps calling an HTTP method with retries.
649
661
func (c * Client ) Do (req * Request ) (* http.Response , error ) {
650
- c .clientInit .Do (func () {
651
- if c .HTTPClient == nil {
652
- c .HTTPClient = cleanhttp .DefaultPooledClient ()
653
- }
654
- })
655
-
656
662
logger := c .logger ()
657
663
658
664
if logger != nil {
@@ -664,6 +670,9 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
664
670
}
665
671
}
666
672
673
+ httpClient := c .getHTTPClient ()
674
+ defer httpClient .Done ()
675
+
667
676
var resp * http.Response
668
677
var attempt int
669
678
var shouldRetry bool
@@ -677,7 +686,6 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
677
686
if req .body != nil {
678
687
body , err := req .body ()
679
688
if err != nil {
680
- c .HTTPClient .CloseIdleConnections ()
681
689
return resp , err
682
690
}
683
691
if c , ok := body .(io.ReadCloser ); ok {
@@ -699,7 +707,8 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
699
707
}
700
708
701
709
// Attempt the request
702
- resp , doErr = c .HTTPClient .Do (req .Request )
710
+
711
+ resp , doErr = httpClient .Do (req .Request )
703
712
704
713
// Check if we should continue with retries.
705
714
shouldRetry , checkErr = c .CheckRetry (req .Context (), resp , doErr )
@@ -768,7 +777,6 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
768
777
select {
769
778
case <- req .Context ().Done ():
770
779
timer .Stop ()
771
- c .HTTPClient .CloseIdleConnections ()
772
780
return nil , req .Context ().Err ()
773
781
case <- timer .C :
774
782
}
@@ -791,8 +799,6 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
791
799
return resp , nil
792
800
}
793
801
794
- defer c .HTTPClient .CloseIdleConnections ()
795
-
796
802
var err error
797
803
if prepareErr != nil {
798
804
err = prepareErr
@@ -841,6 +847,19 @@ func (c *Client) drainBody(body io.ReadCloser) {
841
847
}
842
848
}
843
849
850
+ func (c * Client ) getHTTPClient () HTTPClient {
851
+ if c .HTTPClient != nil {
852
+ return & idleConnectionsClosingClient {
853
+ httpClient : c .HTTPClient ,
854
+ }
855
+ }
856
+ clientFactory := c .HTTPClientFactory
857
+ if clientFactory == nil {
858
+ clientFactory = & CleanPooledClientFactory {}
859
+ }
860
+ return clientFactory .New ()
861
+ }
862
+
844
863
// Get is a shortcut for doing a GET request without making a new client.
845
864
func Get (url string ) (* http.Response , error ) {
846
865
return defaultClient .Get (url )
@@ -917,3 +936,29 @@ func redactURL(u *url.URL) string {
917
936
}
918
937
return ru .String ()
919
938
}
939
+
940
+ var (
941
+ _ HTTPClientFactory = & CleanPooledClientFactory {}
942
+ _ HTTPClient = & idleConnectionsClosingClient {}
943
+ )
944
+
945
+ type CleanPooledClientFactory struct {
946
+ }
947
+
948
+ func (f * CleanPooledClientFactory ) New () HTTPClient {
949
+ return & idleConnectionsClosingClient {
950
+ httpClient : cleanhttp .DefaultPooledClient (),
951
+ }
952
+ }
953
+
954
+ type idleConnectionsClosingClient struct {
955
+ httpClient * http.Client
956
+ }
957
+
958
+ func (c * idleConnectionsClosingClient ) Do (req * http.Request ) (* http.Response , error ) {
959
+ return c .httpClient .Do (req )
960
+ }
961
+
962
+ func (c * idleConnectionsClosingClient ) Done () {
963
+ c .httpClient .CloseIdleConnections ()
964
+ }
0 commit comments