Skip to content

Commit

Permalink
more work
Browse files Browse the repository at this point in the history
  • Loading branch information
mgyucht committed Sep 12, 2024
1 parent 95b0074 commit 35b1a3d
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 55 deletions.
46 changes: 16 additions & 30 deletions databricks/sdk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .config import *
# To preserve backwards compatibility (as these definitions were previously in this module)
from .credentials_provider import *
from .errors import DatabricksError, get_api_error
from .errors import DatabricksError, _Parser, _ErrorCustomizer
from .logger import RoundTrip
from .oauth import retrieve_token
from .retries import retried
Expand Down Expand Up @@ -71,6 +71,8 @@ def __init__(self, cfg: Config = None):
# Default to 60 seconds
self._http_timeout_seconds = cfg.http_timeout_seconds if cfg.http_timeout_seconds else 60

self._error_parser = _Parser(extra_error_customizers=[_AddDebugErrorCustomizer(cfg)])

@property
def account_id(self) -> str:
    """Return the ``account_id`` stored on the configuration object held in ``self._cfg``."""
    cfg = self._cfg
    return cfg.account_id
Expand Down Expand Up @@ -219,27 +221,6 @@ def _is_retryable(err: BaseException) -> Optional[str]:
return f'matched {substring}'
return None

@classmethod
def _parse_retry_after(cls, response: requests.Response) -> Optional[int]:
    """Determine how many seconds to wait before retrying a throttled request.

    The ``Retry-After`` header can contain either an integer number of seconds or an RFC1123
    datetime string (see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After).
    For simplicity only the integer form is parsed, as this is what the Databricks platform
    returns.

    :param response: The HTTP response whose ``Retry-After`` header is inspected.
    :return: The header value as an integer, or ``cls._RETRY_AFTER_DEFAULT`` when the header is
        missing or is not a valid integer.
    """
    retry_after = response.headers.get("Retry-After")
    if retry_after is not None:
        try:
            return int(retry_after)
        except ValueError:
            logger.debug(f'Invalid Retry-After header received: {retry_after}. Defaulting to 1')
    # Either the header is missing (429 responses should include it, but may not) or it could
    # not be parsed as an integer. Falling back to the default keeps _is_retryable() simple.
    return cls._RETRY_AFTER_DEFAULT

def _perform(self,
method: str,
url: str,
Expand All @@ -261,15 +242,8 @@ def _perform(self,
stream=raw,
timeout=self._http_timeout_seconds)
self._record_request_log(response, raw=raw or data is not None or files is not None)
error = get_api_error(response)
error = self._error_parser.get_api_error(response)
if error is not None:
status_code = response.status_code
is_http_unauthorized_or_forbidden = status_code in (401, 403)
is_too_many_requests_or_unavailable = status_code in (429, 503)
if is_http_unauthorized_or_forbidden:
error.message = self._cfg.wrap_debug_info(error.message)
if is_too_many_requests_or_unavailable:
error.retry_after_secs = self._parse_retry_after(response)
raise error from None
return response

Expand All @@ -279,6 +253,18 @@ def _record_request_log(self, response: requests.Response, raw: bool = False) ->
logger.debug(RoundTrip(response, self._cfg.debug_headers, self._debug_truncate_bytes, raw).generate())


class _AddDebugErrorCustomizer(_ErrorCustomizer):
    """Error customizer that decorates the message of 401 and 403 responses.

    The message is passed through ``cfg.wrap_debug_info`` so that the configuration can attach
    its debug information to authentication and authorization failures.
    """

    def __init__(self, cfg: Config):
        """
        :param cfg: The configuration object whose ``wrap_debug_info`` is used to wrap messages.
        """
        self._cfg = cfg

    def customize_error(self, response: requests.Response, kwargs: dict):
        """Rewrite ``kwargs['message']`` in place for unauthorized (401) or forbidden (403) responses.

        :param response: The HTTP response that produced the error.
        :param kwargs: The error constructor arguments; left untouched for any other status code.
        """
        if response.status_code not in (401, 403):
            return
        # When the parser supplied no message, a generic one is wrapped instead.
        kwargs['message'] = self._cfg.wrap_debug_info(kwargs.get('message', 'request failed'))


class StreamingResponse(BinaryIO):
_response: requests.Response
_buffer: bytes
Expand Down
2 changes: 1 addition & 1 deletion databricks/sdk/errors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .base import DatabricksError, ErrorDetail
from .mapper import _error_mapper
from .parser import get_api_error
from .parser import _Parser, _ErrorCustomizer
from .platform import *
from .private_link import PrivateLinkValidationError
from .sdk import *
1 change: 0 additions & 1 deletion databricks/sdk/errors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def __init__(self,
message = f"{scimType} {message}".strip(" ")
error_code = f"SCIM_{status}"
super().__init__(message if message else error)
self.message = message
self.error_code = error_code
self.retry_after_secs = retry_after_secs
self.details = [ErrorDetail.from_dict(detail) for detail in details] if details else []
Expand Down
94 changes: 72 additions & 22 deletions databricks/sdk/errors/parser.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import abc
import functools
import json
import logging
import re
from typing import Optional
from typing import Optional, List

import requests

Expand All @@ -21,6 +22,13 @@ def parse_error(self, response: requests.Response, response_body: bytes) -> Opti
"""Parses an error from the Databricks REST API. If the error cannot be parsed, returns None."""


class _ErrorCustomizer(abc.ABC):
    """Hook for adjusting the arguments used to build an error for a Databricks REST API response.

    Concrete customizers inspect the HTTP response and update the keyword-argument dictionary
    in place before it is turned into an error object.
    """

    @abc.abstractmethod
    def customize_error(self, response: requests.Response, kwargs: dict):
        """Customize the error constructor parameters.

        :param response: The HTTP response that is being converted into an error.
        :param kwargs: The error constructor arguments collected so far; implementations modify
            this dictionary in place.
        """

class _EmptyParser(_ErrorParser):
"""A parser that handles empty responses."""

Expand Down Expand Up @@ -106,10 +114,43 @@ def parse_error(self, response: requests.Response, response_body: bytes) -> Opti
return None


class _RetryAfterCustomizer(_ErrorCustomizer):
    """Error customizer that sets ``retry_after_secs`` for 429 and 503 responses."""

    # Number of seconds used when the Retry-After header is missing or is not an integer.
    _RETRY_AFTER_DEFAULT = 1

    @classmethod
    def _parse_retry_after(cls, response: requests.Response) -> Optional[int]:
        """Determine how many seconds to wait before retrying a throttled request.

        The ``Retry-After`` header can contain either an integer number of seconds or an RFC1123
        datetime string (see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After).
        For simplicity only the integer form is parsed, as this is what the Databricks platform
        returns.

        :param response: The HTTP response whose ``Retry-After`` header is inspected.
        :return: The header value as an integer, or ``_RETRY_AFTER_DEFAULT`` when the header is
            missing or is not a valid integer.
        """
        retry_after = response.headers.get("Retry-After")
        if retry_after is not None:
            try:
                return int(retry_after)
            except ValueError:
                logging.debug(f'Invalid Retry-After header received: {retry_after}. Defaulting to {cls._RETRY_AFTER_DEFAULT}')
        else:
            # 429 requests should include a `Retry-After` header, but it may be missing.
            logging.debug(f'No Retry-After header received in response with status code 429 or 503. Defaulting to {cls._RETRY_AFTER_DEFAULT}')
        # Falling back to the default keeps the retry logic that consumes this value simple.
        return cls._RETRY_AFTER_DEFAULT

    def customize_error(self, response: requests.Response, kwargs: dict):
        """Store the retry delay in ``kwargs['retry_after_secs']`` for 429 and 503 responses.

        :param response: The HTTP response that produced the error.
        :param kwargs: The error constructor arguments; left untouched for any other status code.
        """
        if response.status_code not in (429, 503):
            return
        kwargs['retry_after_secs'] = self._parse_retry_after(response)


# Parsers that are tried in order to extract an API error from a response body. Most errors should be
# handled by the _StandardErrorParser; further parsers can be added here for other error formats. The
# order is not important, because the sets of errors each parser can handle should be disjoint.
_error_parsers = [
    _EmptyParser(),
    _StandardErrorParser(),
    _StringErrorParser(),
    _HtmlErrorParser(),
]
# Customizers that _Parser always uses, ahead of any extra customizers passed to its constructor.
_error_customizers = [
    _RetryAfterCustomizer(),
]


def _unknown_error(response: requests.Response) -> str:
Expand All @@ -124,24 +165,33 @@ def _unknown_error(response: requests.Response) -> str:
f'https://github.com/databricks/databricks-sdk-go/issues. Request log:```{request_log}```')


def get_api_error(response: requests.Response) -> Optional[DatabricksError]:
    """
    Handles responses from the REST API and returns a DatabricksError if the response indicates an error.

    :param response: The response from the REST API.
    :return: A DatabricksError if the response indicates an error, otherwise None.
    """
    if response.ok:
        # Private link failures happen via a redirect to the login page. From a requests-perspective, the
        # request is successful, but the response is not what we expect, so it is handled separately.
        if _is_private_link_redirect(response):
            return _get_private_link_validation_error(response.url)
        return None

    body = response.content
    for parser in _error_parsers:
        try:
            parsed_args = parser.parse_error(response, body)
            if parsed_args:
                return _error_mapper(response, parsed_args)
        except Exception as e:
            # A failing parser must not hide the error: log it and let the next parser try.
            logging.debug(f'Error parsing response with {parser}, continuing', exc_info=e)

    # No parser produced a result, so report a generic error that carries the request log.
    return _error_mapper(response, {'message': 'unable to parse response. ' + _unknown_error(response)})
class _Parser:
    """Converts responses from the Databricks REST API into ``DatabricksError`` instances.

    Each instance combines the module-level default parsers and customizers with any extra ones
    supplied by the caller. The defaults always come first.
    """

    def __init__(self,
                 extra_error_parsers: Optional[List[_ErrorParser]] = None,
                 extra_error_customizers: Optional[List[_ErrorCustomizer]] = None):
        """
        :param extra_error_parsers: Additional parsers to try after the default ones. ``None``
            (the default) means no additional parsers.
        :param extra_error_customizers: Additional customizers to apply after the default ones.
            ``None`` (the default) means no additional customizers.
        """
        # Both arguments default to None, and `list + None` raises TypeError, so each one is
        # normalised to a list before being appended to the defaults. The concatenation builds
        # new lists, so the module-level defaults are never mutated.
        self._error_parsers = _error_parsers + list(extra_error_parsers or [])
        self._error_customizers = _error_customizers + list(extra_error_customizers or [])

    def get_api_error(self, response: requests.Response) -> Optional[DatabricksError]:
        """
        Handles responses from the REST API and returns a DatabricksError if the response indicates an error.

        :param response: The response from the REST API.
        :return: A DatabricksError if the response indicates an error, otherwise None.
        """
        if not response.ok:
            content = response.content
            for parser in self._error_parsers:
                try:
                    error_args = parser.parse_error(response, content)
                    if error_args:
                        # Let every customizer adjust the constructor arguments in place before
                        # the error object is created.
                        for customizer in self._error_customizers:
                            customizer.customize_error(response, error_args)
                        return _error_mapper(response, error_args)
                except Exception as e:
                    # A failure in this parser (or in a customizer run on its result) must not
                    # hide the error: log it and let the next parser try.
                    logging.debug(f'Error parsing response with {parser}, continuing', exc_info=e)
            # No parser produced a result, so report a generic error that carries the request log.
            return _error_mapper(response, {'message': 'unable to parse response. ' + _unknown_error(response)})

        # Private link failures happen via a redirect to the login page. From a requests-perspective, the request
        # is successful, but the response is not what we expect. We need to handle this case separately.
        if _is_private_link_redirect(response):
            return _get_private_link_validation_error(response.url)
        return None
1 change: 0 additions & 1 deletion tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,4 +124,3 @@ def test_get_api_error(response, expected_error, expected_message):
raise errors.get_api_error(response)
assert isinstance(e.value, expected_error)
assert str(e.value) == expected_message
assert e.value.message == expected_message

0 comments on commit 35b1a3d

Please sign in to comment.