Skip to content

Handle SSL errors with retries #160

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 22, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 51 additions & 33 deletions googleapiclient/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import mimetypes
import os
import random
import ssl
import sys
import time
import uuid
Expand All @@ -61,6 +62,46 @@
MAX_URI_LENGTH = 2048


def _retry_request(http, num_retries, req_type, sleep, rand, uri, method, *args,
**kwargs):
"""Retries an HTTP request multiple times while handling errors.

If after all retries the request still fails, last error is either returned as
return value (for HTTP 5xx errors) or thrown (for ssl.SSLError).

Args:
http: Http object to be used to execute request.
num_retries: Maximum number of retries.
req_type: Type of the request (used for logging retries).
sleep, rand: Functions to sleep for random time between retries.
uri: URI to be requested.
method: HTTP method to be used.
args, kwargs: Additional arguments passed to http.request.

Returns:
resp, content - Response from the http request (may be HTTP 5xx).
"""
resp = None
for retry_num in range(num_retries + 1):
if retry_num > 0:
sleep(rand() * 2**retry_num)
logging.warning(
'Retry #%d for %s: %s %s%s' % (retry_num, req_type, method, uri,
', following status: %d' % resp.status if resp else ''))

try:
resp, content = http.request(uri, method, *args, **kwargs)
except ssl.SSLError:
if retry_num == num_retries:
raise
else:
continue
if resp.status < 500:
break

return resp, content


class MediaUploadProgress(object):
"""Status of a resumable upload."""

Expand Down Expand Up @@ -546,16 +587,9 @@ def next_chunk(self, num_retries=0):
}
http = self._request.http

for retry_num in range(num_retries + 1):
if retry_num > 0:
self._sleep(self._rand() * 2**retry_num)
logging.warning(
'Retry #%d for media download: GET %s, following status: %d'
% (retry_num, self._uri, resp.status))

resp, content = http.request(self._uri, headers=headers)
if resp.status < 500:
break
resp, content = _retry_request(
http, num_retries, 'media download', self._sleep, self._rand, self._uri,
'GET', headers=headers)

if resp.status in [200, 206]:
if 'content-location' in resp and resp['content-location'] != self._uri:
Expand Down Expand Up @@ -654,7 +688,7 @@ def __init__(self, http, postproc, uri,

# Pull the multipart boundary out of the content-type header.
major, minor, params = mimeparse.parse_mime_type(
headers.get('content-type', 'application/json'))
self.headers.get('content-type', 'application/json'))

# The size of the non-media part of the request.
self.body_size = len(self.body or '')
Expand Down Expand Up @@ -716,16 +750,9 @@ def execute(self, http=None, num_retries=0):
self.headers['content-length'] = str(len(self.body))

# Handle retries for server-side errors.
for retry_num in range(num_retries + 1):
if retry_num > 0:
self._sleep(self._rand() * 2**retry_num)
logging.warning('Retry #%d for request: %s %s, following status: %d'
% (retry_num, self.method, self.uri, resp.status))

resp, content = http.request(str(self.uri), method=str(self.method),
body=self.body, headers=self.headers)
if resp.status < 500:
break
resp, content = _retry_request(
http, num_retries, 'request', self._sleep, self._rand, str(self.uri),
method=str(self.method), body=self.body, headers=self.headers)

for callback in self.response_callbacks:
callback(resp)
Expand Down Expand Up @@ -799,18 +826,9 @@ def next_chunk(self, http=None, num_retries=0):
start_headers['X-Upload-Content-Length'] = size
start_headers['content-length'] = str(self.body_size)

for retry_num in range(num_retries + 1):
if retry_num > 0:
self._sleep(self._rand() * 2**retry_num)
logging.warning(
'Retry #%d for resumable URI request: %s %s, following status: %d'
% (retry_num, self.method, self.uri, resp.status))

resp, content = http.request(self.uri, method=self.method,
body=self.body,
headers=start_headers)
if resp.status < 500:
break
resp, content = _retry_request(
http, num_retries, 'resumable URI request', self._sleep, self._rand,
self.uri, method=self.method, body=self.body, headers=start_headers)

if resp.status == 200 and 'location' in resp:
self.resumable_uri = resp['location']
Expand Down
59 changes: 59 additions & 0 deletions tests/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import os
import unittest2 as unittest
import random
import ssl
import time

from googleapiclient.discovery import build
Expand Down Expand Up @@ -101,6 +102,20 @@ def apply(self, headers):
headers['authorization'] = self._bearer_token + ' ' + str(self._refreshed)


class HttpMockWithSSLErrors(object):
def __init__(self, num_errors, success_json, success_data):
self.num_errors = num_errors
self.success_json = success_json
self.success_data = success_data

def request(self, *args, **kwargs):
if not self.num_errors:
return httplib2.Response(self.success_json), self.success_data
else:
self.num_errors -= 1
raise ssl.SSLError()


DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')


Expand Down Expand Up @@ -394,6 +409,20 @@ def test_media_io_base_download_handle_4xx(self):

self.assertEqual(self.fd.getvalue(), b'123')

def test_media_io_base_download_retries_ssl_errors(self):
self.request.http = HttpMockWithSSLErrors(
3, {'status': '200', 'content-range': '0-2/3'}, b'123')

download = MediaIoBaseDownload(
fd=self.fd, request=self.request, chunksize=3)
download._sleep = lambda _x: 0 # do nothing
download._rand = lambda: 10

status, done = download.next_chunk(num_retries=3)

self.assertEqual(self.fd.getvalue(), b'123')
self.assertEqual(True, done)

def test_media_io_base_download_retries_5xx(self):
self.request.http = HttpMockSequence([
({'status': '500'}, ''),
Expand Down Expand Up @@ -593,6 +622,36 @@ def test_unicode(self):
self.assertEqual(method, http.method)
self.assertEqual(str, type(http.method))

def test_retry_ssl_errors_non_resumable(self):
model = JsonModel()
request = HttpRequest(
HttpMockWithSSLErrors(3, {'status': '200'}, '{"foo": "bar"}'),
model.response,
u'https://www.example.com/json_api_endpoint')
request._sleep = lambda _x: 0 # do nothing
request._rand = lambda: 10
response = request.execute(num_retries=3)
self.assertEqual({u'foo': u'bar'}, response)

def test_retry_ssl_errors_resumable(self):
with open(datafile('small.png'), 'rb') as small_png_file:
small_png_fd = BytesIO(small_png_file.read())
upload = MediaIoBaseUpload(fd=small_png_fd, mimetype='image/png',
chunksize=500, resumable=True)
model = JsonModel()

request = HttpRequest(
HttpMockWithSSLErrors(
3, {'status': '200', 'location': 'location'}, '{"foo": "bar"}'),
model.response,
u'https://www.example.com/file_upload',
method='POST',
resumable=upload)
request._sleep = lambda _x: 0 # do nothing
request._rand = lambda: 10
response = request.execute(num_retries=3)
self.assertEqual({u'foo': u'bar'}, response)

def test_retry(self):
num_retries = 5
resp_seq = [({'status': '500'}, '')] * num_retries
Expand Down