diff --git a/databricks/sdk/_base_client.py b/databricks/sdk/_base_client.py
index 6424fc1b..ed85dc47 100644
--- a/databricks/sdk/_base_client.py
+++ b/databricks/sdk/_base_client.py
@@ -1,3 +1,4 @@
+import io
 import logging
 import urllib.parse
 from datetime import timedelta
@@ -130,6 +131,14 @@ def flatten_dict(d: Dict[str, Any]) -> Dict[str, Any]:
         flattened = dict(flatten_dict(with_fixed_bools))
         return flattened
 
+    @staticmethod
+    def _is_seekable_stream(data) -> bool:
+        if data is None:
+            return False
+        if not isinstance(data, io.IOBase):
+            return False
+        return data.seekable()
+
     def do(self,
            method: str,
            url: str,
@@ -144,18 +153,31 @@ def do(self,
         if headers is None:
             headers = {}
         headers['User-Agent'] = self._user_agent_base
-        retryable = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
-                            is_retryable=self._is_retryable,
-                            clock=self._clock)
-        response = retryable(self._perform)(method,
-                                            url,
-                                            query=query,
-                                            headers=headers,
-                                            body=body,
-                                            raw=raw,
-                                            files=files,
-                                            data=data,
-                                            auth=auth)
+
+        # Wrap strings and bytes in a seekable stream so that we can rewind them.
+        if isinstance(data, (str, bytes)):
+            data = io.BytesIO(data.encode('utf-8') if isinstance(data, str) else data)
+
+        # Only retry if the request is not a stream or if the stream is seekable and
+        # we can rewind it. This is necessary to avoid bugs where the retry doesn't
+        # re-read already read data from the body.
+        if data is not None and not self._is_seekable_stream(data):
+            logger.debug(f"Retry disabled for non-seekable stream: type={type(data)}")
+            call = self._perform
+        else:
+            call = retried(timeout=timedelta(seconds=self._retry_timeout_seconds),
+                           is_retryable=self._is_retryable,
+                           clock=self._clock)(self._perform)
+
+        response = call(method,
+                        url,
+                        query=query,
+                        headers=headers,
+                        body=body,
+                        raw=raw,
+                        files=files,
+                        data=data,
+                        auth=auth)
 
         resp = dict()
         for header in response_headers if response_headers else []:
@@ -226,6 +248,12 @@ def _perform(self,
                  files=None,
                  data=None,
                  auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None):
+        # Keep track of the initial position of the stream so that we can rewind it if
+        # we need to retry the request.
+        initial_data_position = 0
+        if self._is_seekable_stream(data):
+            initial_data_position = data.tell()
+
         response = self._session.request(method,
                                          url,
                                          params=self._fix_query_string(query),
@@ -237,9 +265,18 @@ def _perform(self,
                                          stream=raw,
                                          timeout=self._http_timeout_seconds)
         self._record_request_log(response, raw=raw or data is not None or files is not None)
+
         error = self._error_parser.get_api_error(response)
         if error is not None:
+            # If the request body is a seekable stream, rewind it so that it is ready
+            # to be read again in case of a retry.
+            #
+            # TODO: This should be moved into a "before-retry" hook to avoid one
+            # unnecessary seek on the last failed retry before aborting.
+            if self._is_seekable_stream(data):
+                data.seek(initial_data_position)
             raise error from None
+
         return response
 
     def _record_request_log(self, response: requests.Response, raw: bool = False) -> None:
diff --git a/tests/test_base_client.py b/tests/test_base_client.py
index 1e133b8f..4b6aaa71 100644
--- a/tests/test_base_client.py
+++ b/tests/test_base_client.py
@@ -1,3 +1,4 @@
+import io
 import random
 from http.server import BaseHTTPRequestHandler
 from typing import Iterator, List
@@ -316,3 +317,131 @@ def mock_iter_content(chunk_size):
     assert received_data == test_data  # all data was received correctly
     assert len(content_chunks) == expected_chunks  # correct number of chunks
     assert all(len(c) <= chunk_size for c in content_chunks)  # chunks don't exceed size
+
+
+def test_is_seekable_stream():
+    client = _BaseClient()
+
+    # Test various input types that are not streams.
+    assert not client._is_seekable_stream(None)  # None
+    assert not client._is_seekable_stream("string data")  # str
+    assert not client._is_seekable_stream(b"binary data")  # bytes
+    assert not client._is_seekable_stream(["list", "data"])  # list
+    assert not client._is_seekable_stream(42)  # int
+
+    # Test non-seekable stream.
+    non_seekable = io.BytesIO(b"test data")
+    non_seekable.seekable = lambda: False
+    assert not client._is_seekable_stream(non_seekable)
+
+    # Test seekable streams.
+    assert client._is_seekable_stream(io.BytesIO(b"test data"))  # BytesIO
+    assert client._is_seekable_stream(io.StringIO("test data"))  # StringIO
+
+    # Test file objects.
+    with open(__file__, 'rb') as f:
+        assert client._is_seekable_stream(f)  # File object
+
+    # Test custom seekable stream.
+    class CustomSeekableStream(io.IOBase):
+
+        def seekable(self):
+            return True
+
+        def seek(self, offset, whence=0):
+            return 0
+
+        def tell(self):
+            return 0
+
+    assert client._is_seekable_stream(CustomSeekableStream())
+
+
+@pytest.mark.parametrize(
+    'input_data',
+    [
+        b"0123456789",  # bytes -> BytesIO
+        "0123456789",  # str -> BytesIO
+        io.BytesIO(b"0123456789"),  # BytesIO directly
+        io.StringIO("0123456789"),  # StringIO
+    ])
+def test_reset_seekable_stream_on_retry(input_data):
+    received_data = []
+
+    # Retry two times before succeeding.
+    def inner(h: BaseHTTPRequestHandler):
+        if len(received_data) == 2:
+            h.send_response(200)
+            h.end_headers()
+        else:
+            h.send_response(429)
+            h.end_headers()
+
+        content_length = int(h.headers.get('Content-Length', 0))
+        if content_length > 0:
+            received_data.append(h.rfile.read(content_length))
+
+    with http_fixture_server(inner) as host:
+        client = _BaseClient()
+
+        # Retries should reset the stream.
+        client.do('POST', f'{host}/foo', data=input_data)
+
+    assert received_data == [b"0123456789", b"0123456789", b"0123456789"]
+
+
+def test_reset_seekable_stream_to_their_initial_position_on_retry():
+    received_data = []
+
+    # Retry two times before succeeding.
+    def inner(h: BaseHTTPRequestHandler):
+        if len(received_data) == 2:
+            h.send_response(200)
+            h.end_headers()
+        else:
+            h.send_response(429)
+            h.end_headers()
+
+        content_length = int(h.headers.get('Content-Length', 0))
+        if content_length > 0:
+            received_data.append(h.rfile.read(content_length))
+
+    input_data = io.BytesIO(b"0123456789")
+    input_data.seek(4)
+
+    with http_fixture_server(inner) as host:
+        client = _BaseClient()
+
+        # Retries should reset the stream to its initial position (offset 4).
+        client.do('POST', f'{host}/foo', data=input_data)
+
+    assert received_data == [b"456789", b"456789", b"456789"]
+    assert input_data.tell() == 10  # EOF
+
+
+def test_no_retry_or_reset_on_non_seekable_stream():
+    requests = []
+
+    # Always respond with a response that triggers a retry.
+    def inner(h: BaseHTTPRequestHandler):
+        content_length = int(h.headers.get('Content-Length', 0))
+        if content_length > 0:
+            requests.append(h.rfile.read(content_length))
+
+        h.send_response(429)
+        h.send_header('Retry-After', '1')
+        h.end_headers()
+
+    input_data = io.BytesIO(b"0123456789")
+    input_data.seekable = lambda: False  # makes the stream appear non-seekable
+
+    with http_fixture_server(inner) as host:
+        client = _BaseClient()
+
+        # Should raise error immediately without retry.
+        with pytest.raises(DatabricksError):
+            client.do('POST', f'{host}/foo', data=input_data)
+
+    # Verify that only one request was made (no retries).
+    assert requests == [b"0123456789"]
+    assert input_data.tell() == 10  # EOF
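
Reviewer note: a minimal sketch of the caller-side behavior this change produces, assuming a `_BaseClient` pointed at a reachable endpoint; the URL and file name below are hypothetical and not part of this diff.

    import io

    from databricks.sdk._base_client import _BaseClient

    client = _BaseClient()

    # str/bytes bodies are wrapped in io.BytesIO inside do(), so they stay retryable.
    client.do('POST', 'https://example.com/foo', data=b'{"hello": "world"}')

    # A seekable file object is rewound after each failed attempt to the position it
    # held when the request started, so every retry re-reads the same bytes.
    with open('payload.json', 'rb') as f:  # hypothetical local file
        client.do('POST', 'https://example.com/foo', data=f)

    # A non-seekable stream disables retries entirely: the request is sent once and
    # any API error is raised immediately.
    pipe_like = io.BytesIO(b'streamed payload')
    pipe_like.seekable = lambda: False
    client.do('POST', 'https://example.com/foo', data=pipe_like)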