Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Fix deserialization of 401/403 errors #758

Merged
merged 9 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 17 additions & 30 deletions databricks/sdk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .config import *
# To preserve backwards compatibility (as these definitions were previously in this module)
from .credentials_provider import *
from .errors import DatabricksError, get_api_error
from .errors import DatabricksError, _ErrorCustomizer, _Parser
from .logger import RoundTrip
from .oauth import retrieve_token
from .retries import retried
Expand Down Expand Up @@ -71,6 +71,8 @@ def __init__(self, cfg: Config = None):
# Default to 60 seconds
self._http_timeout_seconds = cfg.http_timeout_seconds if cfg.http_timeout_seconds else 60

self._error_parser = _Parser(extra_error_customizers=[_AddDebugErrorCustomizer(cfg)])

@property
def account_id(self) -> str:
    """Return the ``account_id`` value held by the client's config object."""
    return self._cfg.account_id
Expand Down Expand Up @@ -219,27 +221,6 @@ def _is_retryable(err: BaseException) -> Optional[str]:
return f'matched {substring}'
return None

@classmethod
def _parse_retry_after(cls, response: requests.Response) -> Optional[int]:
    """Return the delay, in seconds, taken from the response's ``Retry-After`` header.

    Falls back to ``cls._RETRY_AFTER_DEFAULT`` when the header is missing or is
    not an integer.
    """
    retry_after = response.headers.get("Retry-After")
    if retry_after is None:
        # 429 requests should include a `Retry-After` header, but if it's missing,
        # we default to 1 second.
        return cls._RETRY_AFTER_DEFAULT
    # If the request is throttled, try to parse the `Retry-After` header and sleep
    # for the specified number of seconds. Note that this header can contain either
    # an integer or a RFC1123 datetime string.
    # See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
    #
    # For simplicity, we only try to parse it as an integer, as this is what Databricks
    # platform returns. Otherwise, we fall back to the default delay.
    try:
        return int(retry_after)
    except ValueError:
        logger.debug(f'Invalid Retry-After header received: {retry_after}. Defaulting to 1')
        # defaulting to 1 sleep second to make self._is_retryable() simpler
        return cls._RETRY_AFTER_DEFAULT

def _perform(self,
method: str,
url: str,
Expand All @@ -261,15 +242,8 @@ def _perform(self,
stream=raw,
timeout=self._http_timeout_seconds)
self._record_request_log(response, raw=raw or data is not None or files is not None)
error = get_api_error(response)
error = self._error_parser.get_api_error(response)
if error is not None:
status_code = response.status_code
is_http_unauthorized_or_forbidden = status_code in (401, 403)
is_too_many_requests_or_unavailable = status_code in (429, 503)
if is_http_unauthorized_or_forbidden:
error.message = self._cfg.wrap_debug_info(error.message)
if is_too_many_requests_or_unavailable:
error.retry_after_secs = self._parse_retry_after(response)
raise error from None
return response

Expand All @@ -279,6 +253,19 @@ def _record_request_log(self, response: requests.Response, raw: bool = False) ->
logger.debug(RoundTrip(response, self._cfg.debug_headers, self._debug_truncate_bytes, raw).generate())


class _AddDebugErrorCustomizer(_ErrorCustomizer):
    """Error customizer that rewrites the message of HTTP 401/403 errors.

    For those status codes the message is passed through
    ``Config.wrap_debug_info`` (presumably appending configuration debug
    details -- confirm against ``Config``). Other responses are left untouched.
    """

    def __init__(self, cfg: Config):
        self._cfg = cfg

    def customize_error(self, response: requests.Response, kwargs: dict):
        """Replace ``kwargs['message']`` with its debug-wrapped form on 401/403."""
        code = response.status_code
        # The fallback text is used when no deserializer supplied a message.
        current_message = kwargs.get('message', 'request failed')
        if code not in (401, 403):
            return
        kwargs['message'] = self._cfg.wrap_debug_info(current_message)
mgyucht marked this conversation as resolved.
Show resolved Hide resolved


class StreamingResponse(BinaryIO):
_response: requests.Response
_buffer: bytes
Expand Down
4 changes: 2 additions & 2 deletions databricks/sdk/errors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .base import DatabricksError, ErrorDetail
from .mapper import _error_mapper
from .parser import get_api_error
from .customizer import _ErrorCustomizer
from .parser import _Parser
from .platform import *
from .private_link import PrivateLinkValidationError
from .sdk import *
49 changes: 49 additions & 0 deletions databricks/sdk/errors/customizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import abc
import logging
from typing import Optional

import requests


class _ErrorCustomizer(abc.ABC):
    """Interface for hooks that adjust errors built from Databricks REST API responses."""

    @abc.abstractmethod
    def customize_error(self, response: requests.Response, kwargs: dict):
        """Adjust the parameters that will be passed to the error constructor.

        Implementations in this package modify ``kwargs`` in place, based on
        ``response``, and return nothing.
        """


class _RetryAfterCustomizer(_ErrorCustomizer):
    """Adds ``retry_after_secs`` to errors built from HTTP 429 and 503 responses."""

    # Delay (in seconds) used when the header is missing or unusable.
    _RETRY_AFTER_DEFAULT = 1

    @classmethod
    def _parse_retry_after(cls, response: requests.Response) -> Optional[int]:
        """Return the delay, in seconds, requested by the ``Retry-After`` header.

        Falls back to ``_RETRY_AFTER_DEFAULT`` when the header is missing, is not
        an integer, or is a negative integer.
        """
        retry_after = response.headers.get("Retry-After")
        if retry_after is None:
            logging.debug(
                f'No Retry-After header received in response with status code 429 or 503. Defaulting to {cls._RETRY_AFTER_DEFAULT}'
            )
            # 429 requests should include a `Retry-After` header, but if it's missing,
            # we default to 1 second.
            return cls._RETRY_AFTER_DEFAULT
        # If the request is throttled, try to parse the `Retry-After` header and sleep
        # for the specified number of seconds. Note that this header can contain either
        # an integer or a RFC1123 datetime string.
        # See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
        #
        # For simplicity, we only try to parse it as an integer, as this is what Databricks
        # platform returns. Anything else falls back to the default delay.
        try:
            delay = int(retry_after)
        except ValueError:
            delay = None
        # The header only permits non-negative delays; a negative value is treated
        # like any other malformed header so callers never sleep for a negative time.
        if delay is None or delay < 0:
            logging.debug(
                f'Invalid Retry-After header received: {retry_after}. Defaulting to {cls._RETRY_AFTER_DEFAULT}'
            )
            # defaulting to 1 sleep second to make self._is_retryable() simpler
            return cls._RETRY_AFTER_DEFAULT
        return delay

    def customize_error(self, response: requests.Response, kwargs: dict):
        """Set ``kwargs['retry_after_secs']`` when the response status is 429 or 503."""
        if response.status_code in (429, 503):
            kwargs['retry_after_secs'] = self._parse_retry_after(response)
mgyucht marked this conversation as resolved.
Show resolved Hide resolved
100 changes: 100 additions & 0 deletions databricks/sdk/errors/deserializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import abc
import json
import logging
import re
from typing import Optional

import requests


# NOTE: formerly defined in parser.py; renamed to _ErrorDeserializer when moved here.
class _ErrorDeserializer(abc.ABC):
    """Interface for turning a Databricks REST API error response into error arguments."""

    @abc.abstractmethod
    def parse_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]:
        """Extract error details from a Databricks REST API response.

        :param response: the HTTP response that carried the error.
        :param response_body: the raw bytes of the response body.
        :return: a dict describing the error, or None if the error cannot be parsed.
        """


class _EmptyDeserializer(_ErrorDeserializer):
    """Deserializer for error responses that carry no body.

    The response's HTTP reason phrase is used as the error message.
    """

    def parse_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]:
        has_body = len(response_body) != 0
        # A non-empty body is not handled here; None signals "could not parse".
        return None if has_body else {'message': response.reason}


class _StandardErrorDeserializer(_ErrorDeserializer):
    """
    Parses errors from the Databricks REST API using the standard error format.

    Besides the standard ``message`` / ``error_code`` / ``details`` fields, this
    also understands API 1.2-style bodies (``error``) and SCIM error bodies
    (``detail`` / ``status`` / ``scimType``).
    """

    def parse_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]:
        """Return error arguments parsed from a JSON object body, or None if the body is not one."""
        try:
            payload_str = response_body.decode('utf-8')
            resp = json.loads(payload_str)
        except (UnicodeDecodeError, json.JSONDecodeError) as e:
            # A body that is not valid UTF-8 cannot be JSON either; treat both the same way.
            logging.debug('_StandardErrorParser: unable to deserialize response as json', exc_info=e)
            return None

        # Valid JSON is not necessarily an object (e.g. a bare string, number or list);
        # only objects can carry the fields read below.
        if not isinstance(resp, dict):
            logging.debug('_StandardErrorParser: response is valid JSON but not a dictionary')
            return None

        error_args = {
            'message': resp.get('message', 'request failed'),
            'error_code': resp.get('error_code'),
            'details': resp.get('details'),
        }

        # Handle API 1.2-style errors
        if 'error' in resp:
            error_args['message'] = resp['error']

        # Handle SCIM Errors
        detail = resp.get('detail')
        status = resp.get('status')
        scim_type = resp.get('scimType')
        if detail:
            # Handle SCIM error message details
            # @see https://tools.ietf.org/html/rfc7644#section-3.7.3
            if detail == "null":
                detail = "SCIM API Internal Error"
            # `scimType` may be absent; avoid rendering it as the literal text "None".
            error_args['message'] = f"{scim_type or ''} {detail}".strip(" ")
            error_args['error_code'] = f"SCIM_{status}"
        return error_args


class _StringErrorDeserializer(_ErrorDeserializer):
    """
    Parses errors from the Databricks REST API in the format "ERROR_CODE: MESSAGE".
    """

    # Group 1 is the upper-case error code, group 2 the rest of the first line.
    __STRING_ERROR_REGEX = re.compile(r'([A-Z_]+): (.*)')

    def parse_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]:
        """Return the error code, message and HTTP status, or None if the body does not match."""
        try:
            payload_str = response_body.decode('utf-8')
        except UnicodeDecodeError as e:
            # An undecodable body cannot be parsed; report that by returning None
            # rather than letting the decode error escape.
            logging.debug('_StringErrorParser: unable to decode response as utf-8', exc_info=e)
            return None
        match = self.__STRING_ERROR_REGEX.match(payload_str)
        if not match:
            logging.debug('_StringErrorParser: unable to parse response as string')
            return None
        error_code, message = match.groups()
        return {'error_code': error_code, 'message': message, 'status': response.status_code, }


class _HtmlErrorDeserializer(_ErrorDeserializer):
    """
    Parses errors from the Databricks REST API in HTML format.
    """

    # Tried in order: the content of a <pre> element first, then the page <title>.
    __HTML_ERROR_REGEXES = [re.compile(r'<pre>(.*)</pre>'), re.compile(r'<title>(.*)</title>'), ]

    def parse_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]:
        """Return status, message and error code taken from an HTML error page, or None if none is found.

        The error code is the response's reason phrase, upper-cased with spaces
        replaced by underscores.
        """
        try:
            payload_str = response_body.decode('utf-8')
        except UnicodeDecodeError as e:
            # An undecodable body cannot be parsed; report that by returning None
            # rather than letting the decode error escape.
            logging.debug('_HtmlErrorParser: unable to decode response as utf-8', exc_info=e)
            return None
        for regex in self.__HTML_ERROR_REGEXES:
            match = regex.search(payload_str)
            if match:
                # An empty element falls back to the HTTP reason phrase.
                message = match.group(1) if match.group(1) else response.reason
                return {
                    'status': response.status_code,
                    'message': message,
                    'error_code': response.reason.upper().replace(' ', '_')
                }
        logging.debug('_HtmlErrorParser: no <pre> tag found in error response')
        return None
Loading
Loading