Skip to content

Commit

Permalink
[Fix] Rewind seekable streams before retrying (#821)
Browse files Browse the repository at this point in the history
## What changes are proposed in this pull request?

This PR adapts the retry mechanism of `BaseClient` to only retry if (i)
the request is not a stream or (ii) the stream is seekable and can be
reset to its initial position. This fixes a bug that caused retries to
ignore the parts of the request body that had already been processed in
previous attempts.

## How is this tested?

Added unit tests to verify that (i) non-seekable streams are not
retried, and (ii) seekable streams are properly reset before retrying.
  • Loading branch information
renaudhartert-db authored Nov 15, 2024
1 parent ee6e70a commit e8b7916
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 12 deletions.
61 changes: 49 additions & 12 deletions databricks/sdk/_base_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import io
import logging
import urllib.parse
from datetime import timedelta
Expand Down Expand Up @@ -130,6 +131,14 @@ def flatten_dict(d: Dict[str, Any]) -> Dict[str, Any]:
flattened = dict(flatten_dict(with_fixed_bools))
return flattened

@staticmethod
def _is_seekable_stream(data) -> bool:
if data is None:
return False
if not isinstance(data, io.IOBase):
return False
return data.seekable()

def do(self,
method: str,
url: str,
Expand All @@ -144,18 +153,31 @@ def do(self,
if headers is None:
headers = {}
headers['User-Agent'] = self._user_agent_base
retryable = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
is_retryable=self._is_retryable,
clock=self._clock)
response = retryable(self._perform)(method,
url,
query=query,
headers=headers,
body=body,
raw=raw,
files=files,
data=data,
auth=auth)

# Wrap strings and bytes in a seekable stream so that we can rewind them.
if isinstance(data, (str, bytes)):
data = io.BytesIO(data.encode('utf-8') if isinstance(data, str) else data)

# Only retry if the request is not a stream or if the stream is seekable and
# we can rewind it. This is necessary to avoid bugs where the retry doesn't
# re-read already read data from the body.
if data is not None and not self._is_seekable_stream(data):
logger.debug(f"Retry disabled for non-seekable stream: type={type(data)}")
call = self._perform
else:
call = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
is_retryable=self._is_retryable,
clock=self._clock)(self._perform)

response = call(method,
url,
query=query,
headers=headers,
body=body,
raw=raw,
files=files,
data=data,
auth=auth)

resp = dict()
for header in response_headers if response_headers else []:
Expand Down Expand Up @@ -226,6 +248,12 @@ def _perform(self,
files=None,
data=None,
auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None):
# Keep track of the initial position of the stream so that we can rewind it if
# we need to retry the request.
initial_data_position = 0
if self._is_seekable_stream(data):
initial_data_position = data.tell()

response = self._session.request(method,
url,
params=self._fix_query_string(query),
Expand All @@ -237,9 +265,18 @@ def _perform(self,
stream=raw,
timeout=self._http_timeout_seconds)
self._record_request_log(response, raw=raw or data is not None or files is not None)

error = self._error_parser.get_api_error(response)
if error is not None:
# If the request body is a seekable stream, rewind it so that it is ready
# to be read again in case of a retry.
#
# TODO: This should be moved into a "before-retry" hook to avoid one
# unnecessary seek on the last failed retry before aborting.
if self._is_seekable_stream(data):
data.seek(initial_data_position)
raise error from None

return response

def _record_request_log(self, response: requests.Response, raw: bool = False) -> None:
Expand Down
129 changes: 129 additions & 0 deletions tests/test_base_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import io
import random
from http.server import BaseHTTPRequestHandler
from typing import Iterator, List
Expand Down Expand Up @@ -316,3 +317,131 @@ def mock_iter_content(chunk_size):
assert received_data == test_data # all data was received correctly
assert len(content_chunks) == expected_chunks # correct number of chunks
assert all(len(c) <= chunk_size for c in content_chunks) # chunks don't exceed size


def test_is_seekable_stream():
    """_is_seekable_stream must reject non-streams and non-seekable streams."""
    client = _BaseClient()

    # Plain values that are not streams at all are never seekable streams.
    for value in (None, "string data", b"binary data", ["list", "data"], 42):
        assert not client._is_seekable_stream(value)

    # A stream that reports itself as non-seekable is rejected.
    masked = io.BytesIO(b"test data")
    masked.seekable = lambda: False
    assert not client._is_seekable_stream(masked)

    # Ordinary in-memory streams are seekable.
    assert client._is_seekable_stream(io.BytesIO(b"test data"))
    assert client._is_seekable_stream(io.StringIO("test data"))

    # Real file objects opened in binary mode are seekable too.
    with open(__file__, 'rb') as handle:
        assert client._is_seekable_stream(handle)

    # A user-defined IOBase subclass that opts into seeking is accepted.
    class CustomSeekableStream(io.IOBase):

        def seekable(self):
            return True

        def seek(self, offset, whence=0):
            return 0

        def tell(self):
            return 0

    assert client._is_seekable_stream(CustomSeekableStream())


@pytest.mark.parametrize(
    'input_data',
    [
        b"0123456789", # bytes -> BytesIO
        "0123456789", # str -> BytesIO
        io.BytesIO(b"0123456789"), # BytesIO directly
        io.StringIO("0123456789"), # StringIO
    ])
def test_reset_seekable_stream_on_retry(input_data):
    """Every retry attempt must re-send the full request body."""
    received_data = []

    def handler(h: BaseHTTPRequestHandler):
        # Fail the first two attempts with 429 so the client retries twice,
        # then succeed on the third.
        h.send_response(200 if len(received_data) == 2 else 429)
        h.end_headers()

        length = int(h.headers.get('Content-Length', 0))
        if length > 0:
            received_data.append(h.rfile.read(length))

    with http_fixture_server(handler) as host:
        # The stream must be rewound before each retry.
        _BaseClient().do('POST', f'{host}/foo', data=input_data)

    # All three attempts (two failures + one success) saw the whole body.
    assert received_data == [b"0123456789"] * 3


def test_reset_seekable_stream_to_their_initial_position_on_retry():
    """Retries must rewind to the stream's position at call time, not to 0."""
    received_data = []

    def handler(h: BaseHTTPRequestHandler):
        # Two 429s force two retries; the third attempt succeeds.
        h.send_response(200 if len(received_data) == 2 else 429)
        h.end_headers()

        length = int(h.headers.get('Content-Length', 0))
        if length > 0:
            received_data.append(h.rfile.read(length))

    stream = io.BytesIO(b"0123456789")
    stream.seek(4) # start mid-stream: only b"456789" should ever be sent

    with http_fixture_server(handler) as host:
        # Each retry must rewind to offset 4, not to the beginning.
        _BaseClient().do('POST', f'{host}/foo', data=stream)

    assert received_data == [b"456789"] * 3
    assert stream.tell() == 10 # fully consumed after the final attempt (EOF)


def test_no_retry_or_reset_on_non_seekable_stream():
    """A non-seekable body must be sent exactly once and never rewound."""
    seen_bodies = []

    def handler(h: BaseHTTPRequestHandler):
        length = int(h.headers.get('Content-Length', 0))
        if length > 0:
            seen_bodies.append(h.rfile.read(length))

        # Always answer 429 so a retrying client would attempt again.
        h.send_response(429)
        h.send_header('Retry-After', '1')
        h.end_headers()

    stream = io.BytesIO(b"0123456789")
    stream.seekable = lambda: False # masquerade as a non-seekable stream

    with http_fixture_server(handler) as host:
        # The 429 must surface immediately instead of being retried.
        with pytest.raises(DatabricksError):
            _BaseClient().do('POST', f'{host}/foo', data=stream)

    assert seen_bodies == [b"0123456789"] # exactly one attempt, no retry
    assert stream.tell() == 10 # stream was consumed and never rewound (EOF)

0 comments on commit e8b7916

Please sign in to comment.