Skip to content

Commit

Permalink
[Fix] Fix deserialization of 401/403 errors (databricks#758)
Browse files Browse the repository at this point in the history
## Changes
databricks#741 introduced a change to how an error message was modified in
`ApiClient._perform`. Previously, arguments to the DatabricksError
constructor were modified as a dictionary in `_perform`. After that
change, `get_api_error` started to return a `DatabricksError` instance
whose attributes were modified. The `message` attribute referred to in
that change does not exist in the DatabricksError class: there is a
`message` constructor parameter, but it is not set as an attribute.

This PR refactors the error handling logic slightly to restore the
original behavior. In doing this, we decouple all error-parsing and
customizing logic out of ApiClient. This also sets us up to allow for
further extension of error parsing and customization in the future, a
feature that I have seen present in other SDKs.

Fixes databricks#755.

## Tests
<!-- 
How is this tested? Please see the checklist below and also describe any
other relevant tests
-->

- [ ] `make test` run locally
- [ ] `make fmt` applied
- [ ] relevant integration tests applied
  • Loading branch information
mgyucht authored and aravind-segu committed Sep 18, 2024
1 parent 5afb9a2 commit 615dd98
Show file tree
Hide file tree
Showing 8 changed files with 324 additions and 184 deletions.
47 changes: 17 additions & 30 deletions databricks/sdk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .config import *
# To preserve backwards compatibility (as these definitions were previously in this module)
from .credentials_provider import *
from .errors import DatabricksError, get_api_error
from .errors import DatabricksError, _ErrorCustomizer, _Parser
from .logger import RoundTrip
from .oauth import retrieve_token
from .retries import retried
Expand Down Expand Up @@ -71,6 +71,8 @@ def __init__(self, cfg: Config = None):
# Default to 60 seconds
self._http_timeout_seconds = cfg.http_timeout_seconds if cfg.http_timeout_seconds else 60

self._error_parser = _Parser(extra_error_customizers=[_AddDebugErrorCustomizer(cfg)])

@property
def account_id(self) -> str:
return self._cfg.account_id
Expand Down Expand Up @@ -219,27 +221,6 @@ def _is_retryable(err: BaseException) -> Optional[str]:
return f'matched {substring}'
return None

@classmethod
def _parse_retry_after(cls, response: requests.Response) -> Optional[int]:
retry_after = response.headers.get("Retry-After")
if retry_after is None:
# 429 requests should include a `Retry-After` header, but if it's missing,
# we default to 1 second.
return cls._RETRY_AFTER_DEFAULT
# If the request is throttled, try parse the `Retry-After` header and sleep
# for the specified number of seconds. Note that this header can contain either
# an integer or a RFC1123 datetime string.
# See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
#
# For simplicity, we only try to parse it as an integer, as this is what Databricks
# platform returns. Otherwise, we fall back and don't sleep.
try:
return int(retry_after)
except ValueError:
logger.debug(f'Invalid Retry-After header received: {retry_after}. Defaulting to 1')
# defaulting to 1 sleep second to make self._is_retryable() simpler
return cls._RETRY_AFTER_DEFAULT

def _perform(self,
method: str,
url: str,
Expand All @@ -261,15 +242,8 @@ def _perform(self,
stream=raw,
timeout=self._http_timeout_seconds)
self._record_request_log(response, raw=raw or data is not None or files is not None)
error = get_api_error(response)
error = self._error_parser.get_api_error(response)
if error is not None:
status_code = response.status_code
is_http_unauthorized_or_forbidden = status_code in (401, 403)
is_too_many_requests_or_unavailable = status_code in (429, 503)
if is_http_unauthorized_or_forbidden:
error.message = self._cfg.wrap_debug_info(error.message)
if is_too_many_requests_or_unavailable:
error.retry_after_secs = self._parse_retry_after(response)
raise error from None
return response

Expand All @@ -279,6 +253,19 @@ def _record_request_log(self, response: requests.Response, raw: bool = False) ->
logger.debug(RoundTrip(response, self._cfg.debug_headers, self._debug_truncate_bytes, raw).generate())


class _AddDebugErrorCustomizer(_ErrorCustomizer):
"""An error customizer that adds debug information about the configuration to unauthenticated and
unauthorized errors."""

def __init__(self, cfg: Config):
self._cfg = cfg

def customize_error(self, response: requests.Response, kwargs: dict):
if response.status_code in (401, 403):
message = kwargs.get('message', 'request failed')
kwargs['message'] = self._cfg.wrap_debug_info(message)


class StreamingResponse(BinaryIO):
_response: requests.Response
_buffer: bytes
Expand Down
4 changes: 2 additions & 2 deletions databricks/sdk/errors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .base import DatabricksError, ErrorDetail
from .mapper import _error_mapper
from .parser import get_api_error
from .customizer import _ErrorCustomizer
from .parser import _Parser
from .platform import *
from .private_link import PrivateLinkValidationError
from .sdk import *
50 changes: 50 additions & 0 deletions databricks/sdk/errors/customizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import abc
import logging

import requests


class _ErrorCustomizer(abc.ABC):
"""A customizer for errors from the Databricks REST API."""

@abc.abstractmethod
def customize_error(self, response: requests.Response, kwargs: dict):
"""Customize the error constructor parameters."""


class _RetryAfterCustomizer(_ErrorCustomizer):
"""An error customizer that sets the retry_after_secs parameter based on the Retry-After header."""

_DEFAULT_RETRY_AFTER_SECONDS = 1
"""The default number of seconds to wait before retrying a request if the Retry-After header is missing or is not
a valid integer."""

@classmethod
def _parse_retry_after(cls, response: requests.Response) -> int:
retry_after = response.headers.get("Retry-After")
if retry_after is None:
logging.debug(
f'No Retry-After header received in response with status code 429 or 503. Defaulting to {cls._DEFAULT_RETRY_AFTER_SECONDS}'
)
# 429 requests should include a `Retry-After` header, but if it's missing,
# we default to 1 second.
return cls._DEFAULT_RETRY_AFTER_SECONDS
# If the request is throttled, try parse the `Retry-After` header and sleep
# for the specified number of seconds. Note that this header can contain either
# an integer or a RFC1123 datetime string.
# See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
#
# For simplicity, we only try to parse it as an integer, as this is what Databricks
# platform returns. Otherwise, we fall back and don't sleep.
try:
return int(retry_after)
except ValueError:
logging.debug(
f'Invalid Retry-After header received: {retry_after}. Defaulting to {cls._DEFAULT_RETRY_AFTER_SECONDS}'
)
# defaulting to 1 sleep second to make self._is_retryable() simpler
return cls._DEFAULT_RETRY_AFTER_SECONDS

def customize_error(self, response: requests.Response, kwargs: dict):
if response.status_code in (429, 503):
kwargs['retry_after_secs'] = self._parse_retry_after(response)
106 changes: 106 additions & 0 deletions databricks/sdk/errors/deserializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import abc
import json
import logging
import re
from typing import Optional

import requests


class _ErrorDeserializer(abc.ABC):
"""A parser for errors from the Databricks REST API."""

@abc.abstractmethod
def deserialize_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]:
"""Parses an error from the Databricks REST API. If the error cannot be parsed, returns None."""


class _EmptyDeserializer(_ErrorDeserializer):
"""A parser that handles empty responses."""

def deserialize_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]:
if len(response_body) == 0:
return {'message': response.reason}
return None


class _StandardErrorDeserializer(_ErrorDeserializer):
"""
Parses errors from the Databricks REST API using the standard error format.
"""

def deserialize_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]:
try:
payload_str = response_body.decode('utf-8')
resp = json.loads(payload_str)
except UnicodeDecodeError as e:
logging.debug('_StandardErrorParser: unable to decode response using utf-8', exc_info=e)
return None
except json.JSONDecodeError as e:
logging.debug('_StandardErrorParser: unable to deserialize response as json', exc_info=e)
return None
if not isinstance(resp, dict):
logging.debug('_StandardErrorParser: response is valid JSON but not a dictionary')
return None

error_args = {
'message': resp.get('message', 'request failed'),
'error_code': resp.get('error_code'),
'details': resp.get('details'),
}

# Handle API 1.2-style errors
if 'error' in resp:
error_args['message'] = resp['error']

# Handle SCIM Errors
detail = resp.get('detail')
status = resp.get('status')
scim_type = resp.get('scimType')
if detail:
# Handle SCIM error message details
# @see https://tools.ietf.org/html/rfc7644#section-3.7.3
if detail == "null":
detail = "SCIM API Internal Error"
error_args['message'] = f"{scim_type} {detail}".strip(" ")
error_args['error_code'] = f"SCIM_{status}"
return error_args


class _StringErrorDeserializer(_ErrorDeserializer):
"""
Parses errors from the Databricks REST API in the format "ERROR_CODE: MESSAGE".
"""

__STRING_ERROR_REGEX = re.compile(r'([A-Z_]+): (.*)')

def deserialize_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]:
payload_str = response_body.decode('utf-8')
match = self.__STRING_ERROR_REGEX.match(payload_str)
if not match:
logging.debug('_StringErrorParser: unable to parse response as string')
return None
error_code, message = match.groups()
return {'error_code': error_code, 'message': message, 'status': response.status_code, }


class _HtmlErrorDeserializer(_ErrorDeserializer):
"""
Parses errors from the Databricks REST API in HTML format.
"""

__HTML_ERROR_REGEXES = [re.compile(r'<pre>(.*)</pre>'), re.compile(r'<title>(.*)</title>'), ]

def deserialize_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]:
payload_str = response_body.decode('utf-8')
for regex in self.__HTML_ERROR_REGEXES:
match = regex.search(payload_str)
if match:
message = match.group(1) if match.group(1) else response.reason
return {
'status': response.status_code,
'message': message,
'error_code': response.reason.upper().replace(' ', '_')
}
logging.debug('_HtmlErrorParser: no <pre> tag found in error response')
return None
Loading

0 comments on commit 615dd98

Please sign in to comment.