-
Notifications
You must be signed in to change notification settings - Fork 131
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
988 additions
and
721 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,343 @@ | ||
import logging | ||
from datetime import timedelta | ||
from types import TracebackType | ||
from typing import (Any, BinaryIO, Callable, Dict, Iterable, Iterator, List, | ||
Optional, Type, Union) | ||
import urllib.parse | ||
|
||
import requests | ||
import requests.adapters | ||
|
||
from . import useragent | ||
from .casing import Casing | ||
from .clock import Clock, RealClock | ||
from .errors import DatabricksError, _ErrorCustomizer, _Parser | ||
from .logger import RoundTrip | ||
from .retries import retried | ||
|
||
logger = logging.getLogger('databricks.sdk') | ||
|
||
|
||
def fix_host_if_needed(host: Optional[str]) -> Optional[str]: | ||
if not host: | ||
return host | ||
|
||
# Add a default scheme if it's missing | ||
if '://' not in host: | ||
host = 'https://' + host | ||
|
||
o = urllib.parse.urlparse(host) | ||
# remove trailing slash | ||
path = o.path.rstrip('/') | ||
# remove port if 443 | ||
netloc = o.netloc | ||
if o.port == 443: | ||
netloc = netloc.split(':')[0] | ||
|
||
return urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment)) | ||
|
||
|
||
class _BaseClient: | ||
|
||
def __init__(self, | ||
debug_truncate_bytes: int = None, | ||
retry_timeout_seconds: int = None, | ||
user_agent_base: str = None, | ||
header_factory: Callable[[], dict] = None, | ||
max_connection_pools: int = None, | ||
max_connections_per_pool: int = None, | ||
pool_block: bool = True, | ||
http_timeout_seconds: float = None, | ||
extra_error_customizers: List[_ErrorCustomizer] = None, | ||
debug_headers: bool = False, | ||
clock: Clock = None): | ||
""" | ||
:param debug_truncate_bytes: | ||
:param retry_timeout_seconds: | ||
:param user_agent_base: | ||
:param header_factory: A function that returns a dictionary of headers to include in the request. | ||
:param max_connection_pools: Number of urllib3 connection pools to cache before discarding the least | ||
recently used pool. Python requests default value is 10. | ||
:param max_connections_per_pool: The maximum number of connections to save in the pool. Improves performance | ||
in multithreaded situations. For now, we're setting it to the same value as connection_pool_size. | ||
:param pool_block: If pool_block is False, then more connections will are created, but not saved after the | ||
first use. Blocks when no free connections are available. urllib3 ensures that no more than | ||
pool_maxsize connections are used at a time. Prevents platform from flooding. By default, requests library | ||
doesn't block. | ||
:param http_timeout_seconds: | ||
:param extra_error_customizers: | ||
:param debug_headers: Whether to include debug headers in the request log. | ||
:param clock: Clock object to use for time-related operations. | ||
""" | ||
|
||
self._debug_truncate_bytes = debug_truncate_bytes or 96 | ||
self._debug_headers = debug_headers | ||
self._retry_timeout_seconds = retry_timeout_seconds or 300 | ||
self._user_agent_base = user_agent_base or useragent.to_string() | ||
self._header_factory = header_factory | ||
self._clock = clock or RealClock() | ||
self._session = requests.Session() | ||
self._session.auth = self._authenticate | ||
|
||
# We don't use `max_retries` from HTTPAdapter to align with a more production-ready | ||
# retry strategy established in the Databricks SDK for Go. See _is_retryable and | ||
# @retried for more details. | ||
http_adapter = requests.adapters.HTTPAdapter(pool_connections=max_connections_per_pool or 20, | ||
pool_maxsize=max_connection_pools or 20, | ||
pool_block=pool_block) | ||
self._session.mount("https://", http_adapter) | ||
|
||
# Default to 60 seconds | ||
self._http_timeout_seconds = http_timeout_seconds or 60 | ||
|
||
self._error_parser = _Parser(extra_error_customizers=extra_error_customizers) | ||
|
||
def _authenticate(self, r: requests.PreparedRequest) -> requests.PreparedRequest: | ||
if self._header_factory: | ||
headers = self._header_factory() | ||
for k, v in headers.items(): | ||
r.headers[k] = v | ||
return r | ||
|
||
@staticmethod | ||
def _fix_query_string(query: Optional[dict] = None) -> Optional[dict]: | ||
# Convert True -> "true" for Databricks APIs to understand booleans. | ||
# See: https://github.com/databricks/databricks-sdk-py/issues/142 | ||
if query is None: | ||
return None | ||
with_fixed_bools = {k: v if type(v) != bool else ('true' if v else 'false') for k, v in query.items()} | ||
|
||
# Query parameters may be nested, e.g. | ||
# {'filter_by': {'user_ids': [123, 456]}} | ||
# The HTTP-compatible representation of this is | ||
# filter_by.user_ids=123&filter_by.user_ids=456 | ||
# To achieve this, we convert the above dictionary to | ||
# {'filter_by.user_ids': [123, 456]} | ||
# See the following for more information: | ||
# https://cloud.google.com/endpoints/docs/grpc-service-config/reference/rpc/google.api#google.api.HttpRule | ||
def flatten_dict(d: Dict[str, Any]) -> Dict[str, Any]: | ||
for k1, v1 in d.items(): | ||
if isinstance(v1, dict): | ||
v1 = dict(flatten_dict(v1)) | ||
for k2, v2 in v1.items(): | ||
yield f"{k1}.{k2}", v2 | ||
else: | ||
yield k1, v1 | ||
|
||
flattened = dict(flatten_dict(with_fixed_bools)) | ||
return flattened | ||
|
||
def do(self, | ||
method: str, | ||
url: str, | ||
query: dict = None, | ||
headers: dict = None, | ||
body: dict = None, | ||
raw: bool = False, | ||
files=None, | ||
data=None, | ||
auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None, | ||
response_headers: List[str] = None) -> Union[dict, list, BinaryIO]: | ||
if headers is None: | ||
headers = {} | ||
headers['User-Agent'] = self._user_agent_base | ||
retryable = retried(timeout=timedelta(seconds=self._retry_timeout_seconds), | ||
is_retryable=self._is_retryable, | ||
clock=self._clock) | ||
response = retryable(self._perform)(method, | ||
url, | ||
query=query, | ||
headers=headers, | ||
body=body, | ||
raw=raw, | ||
files=files, | ||
data=data, | ||
auth=auth) | ||
|
||
resp = dict() | ||
for header in response_headers if response_headers else []: | ||
resp[header] = response.headers.get(Casing.to_header_case(header)) | ||
if raw: | ||
resp["contents"] = _StreamingResponse(response) | ||
return resp | ||
if not len(response.content): | ||
return resp | ||
|
||
json_response = response.json() | ||
if json_response is None: | ||
return resp | ||
|
||
if isinstance(json_response, list): | ||
return json_response | ||
|
||
return {**resp, **json_response} | ||
|
||
@staticmethod | ||
def _is_retryable(err: BaseException) -> Optional[str]: | ||
# this method is Databricks-specific port of urllib3 retries | ||
# (see https://github.com/urllib3/urllib3/blob/main/src/urllib3/util/retry.py) | ||
# and Databricks SDK for Go retries | ||
# (see https://github.com/databricks/databricks-sdk-go/blob/main/apierr/errors.go) | ||
from urllib3.exceptions import ProxyError | ||
if isinstance(err, ProxyError): | ||
err = err.original_error | ||
if isinstance(err, requests.ConnectionError): | ||
# corresponds to `connection reset by peer` and `connection refused` errors from Go, | ||
# which are generally related to the temporary glitches in the networking stack, | ||
# also caused by endpoint protection software, like ZScaler, to drop connections while | ||
# not yet authenticated. | ||
# | ||
# return a simple string for debug log readability, as `raise TimeoutError(...) from err` | ||
# will bubble up the original exception in case we reach max retries. | ||
return f'cannot connect' | ||
if isinstance(err, requests.Timeout): | ||
# corresponds to `TLS handshake timeout` and `i/o timeout` in Go. | ||
# | ||
# return a simple string for debug log readability, as `raise TimeoutError(...) from err` | ||
# will bubble up the original exception in case we reach max retries. | ||
return f'timeout' | ||
if isinstance(err, DatabricksError): | ||
message = str(err) | ||
transient_error_string_matches = [ | ||
"com.databricks.backend.manager.util.UnknownWorkerEnvironmentException", | ||
"does not have any associated worker environments", "There is no worker environment with id", | ||
"Unknown worker environment", "ClusterNotReadyException", "Unexpected error", | ||
"Please try again later or try a faster operation.", | ||
"RPC token bucket limit has been exceeded", | ||
] | ||
for substring in transient_error_string_matches: | ||
if substring not in message: | ||
continue | ||
return f'matched {substring}' | ||
return None | ||
|
||
def _perform(self, | ||
method: str, | ||
url: str, | ||
query: dict = None, | ||
headers: dict = None, | ||
body: dict = None, | ||
raw: bool = False, | ||
files=None, | ||
data=None, | ||
auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None): | ||
response = self._session.request(method, | ||
url, | ||
params=self._fix_query_string(query), | ||
json=body, | ||
headers=headers, | ||
files=files, | ||
data=data, | ||
auth=auth, | ||
stream=raw, | ||
timeout=self._http_timeout_seconds) | ||
self._record_request_log(response, raw=raw or data is not None or files is not None) | ||
error = self._error_parser.get_api_error(response) | ||
if error is not None: | ||
raise error from None | ||
return response | ||
|
||
def _record_request_log(self, response: requests.Response, raw: bool = False) -> None: | ||
if not logger.isEnabledFor(logging.DEBUG): | ||
return | ||
logger.debug(RoundTrip(response, self._debug_headers, self._debug_truncate_bytes, raw).generate()) | ||
|
||
|
||
class _StreamingResponse(BinaryIO): | ||
_response: requests.Response | ||
_buffer: bytes | ||
_content: Union[Iterator[bytes], None] | ||
_chunk_size: Union[int, None] | ||
_closed: bool = False | ||
|
||
def fileno(self) -> int: | ||
pass | ||
|
||
def flush(self) -> int: | ||
pass | ||
|
||
def __init__(self, response: requests.Response, chunk_size: Union[int, None] = None): | ||
self._response = response | ||
self._buffer = b'' | ||
self._content = None | ||
self._chunk_size = chunk_size | ||
|
||
def _open(self) -> None: | ||
if self._closed: | ||
raise ValueError("I/O operation on closed file") | ||
if not self._content: | ||
self._content = self._response.iter_content(chunk_size=self._chunk_size) | ||
|
||
def __enter__(self) -> BinaryIO: | ||
self._open() | ||
return self | ||
|
||
def set_chunk_size(self, chunk_size: Union[int, None]) -> None: | ||
self._chunk_size = chunk_size | ||
|
||
def close(self) -> None: | ||
self._response.close() | ||
self._closed = True | ||
|
||
def isatty(self) -> bool: | ||
return False | ||
|
||
def read(self, n: int = -1) -> bytes: | ||
self._open() | ||
read_everything = n < 0 | ||
remaining_bytes = n | ||
res = b'' | ||
while remaining_bytes > 0 or read_everything: | ||
if len(self._buffer) == 0: | ||
try: | ||
self._buffer = next(self._content) | ||
except StopIteration: | ||
break | ||
bytes_available = len(self._buffer) | ||
to_read = bytes_available if read_everything else min(remaining_bytes, bytes_available) | ||
res += self._buffer[:to_read] | ||
self._buffer = self._buffer[to_read:] | ||
remaining_bytes -= to_read | ||
return res | ||
|
||
def readable(self) -> bool: | ||
return self._content is not None | ||
|
||
def readline(self, __limit: int = ...) -> bytes: | ||
raise NotImplementedError() | ||
|
||
def readlines(self, __hint: int = ...) -> List[bytes]: | ||
raise NotImplementedError() | ||
|
||
def seek(self, __offset: int, __whence: int = ...) -> int: | ||
raise NotImplementedError() | ||
|
||
def seekable(self) -> bool: | ||
return False | ||
|
||
def tell(self) -> int: | ||
raise NotImplementedError() | ||
|
||
def truncate(self, __size: Union[int, None] = ...) -> int: | ||
raise NotImplementedError() | ||
|
||
def writable(self) -> bool: | ||
return False | ||
|
||
def write(self, s: Union[bytes, bytearray]) -> int: | ||
raise NotImplementedError() | ||
|
||
def writelines(self, lines: Iterable[bytes]) -> None: | ||
raise NotImplementedError() | ||
|
||
def __next__(self) -> bytes: | ||
return self.read(1) | ||
|
||
def __iter__(self) -> Iterator[bytes]: | ||
return self._content | ||
|
||
def __exit__(self, t: Union[Type[BaseException], None], value: Union[BaseException, None], | ||
traceback: Union[TracebackType, None]) -> None: | ||
self._content = None | ||
self._buffer = b'' | ||
self.close() |
Oops, something went wrong.