Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Fix deserialization of 401/403 errors #758

Merged
merged 9 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 17 additions & 30 deletions databricks/sdk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .config import *
# To preserve backwards compatibility (as these definitions were previously in this module)
from .credentials_provider import *
from .errors import DatabricksError, get_api_error
from .errors import DatabricksError, _ErrorCustomizer, _Parser
from .logger import RoundTrip
from .oauth import retrieve_token
from .retries import retried
Expand Down Expand Up @@ -71,6 +71,8 @@ def __init__(self, cfg: Config = None):
# Default to 60 seconds
self._http_timeout_seconds = cfg.http_timeout_seconds if cfg.http_timeout_seconds else 60

self._error_parser = _Parser(extra_error_customizers=[_AddDebugErrorCustomizer(cfg)])

@property
def account_id(self) -> str:
return self._cfg.account_id
Expand Down Expand Up @@ -219,27 +221,6 @@ def _is_retryable(err: BaseException) -> Optional[str]:
return f'matched {substring}'
return None

@classmethod
def _parse_retry_after(cls, response: requests.Response) -> Optional[int]:
retry_after = response.headers.get("Retry-After")
if retry_after is None:
# 429 requests should include a `Retry-After` header, but if it's missing,
# we default to 1 second.
return cls._RETRY_AFTER_DEFAULT
# If the request is throttled, try parse the `Retry-After` header and sleep
# for the specified number of seconds. Note that this header can contain either
# an integer or a RFC1123 datetime string.
# See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
#
# For simplicity, we only try to parse it as an integer, as this is what Databricks
# platform returns. Otherwise, we fall back and don't sleep.
try:
return int(retry_after)
except ValueError:
logger.debug(f'Invalid Retry-After header received: {retry_after}. Defaulting to 1')
# defaulting to 1 sleep second to make self._is_retryable() simpler
return cls._RETRY_AFTER_DEFAULT

def _perform(self,
method: str,
url: str,
Expand All @@ -261,15 +242,8 @@ def _perform(self,
stream=raw,
timeout=self._http_timeout_seconds)
self._record_request_log(response, raw=raw or data is not None or files is not None)
error = get_api_error(response)
error = self._error_parser.get_api_error(response)
if error is not None:
status_code = response.status_code
is_http_unauthorized_or_forbidden = status_code in (401, 403)
is_too_many_requests_or_unavailable = status_code in (429, 503)
if is_http_unauthorized_or_forbidden:
error.message = self._cfg.wrap_debug_info(error.message)
if is_too_many_requests_or_unavailable:
error.retry_after_secs = self._parse_retry_after(response)
raise error from None
return response

Expand All @@ -279,6 +253,19 @@ def _record_request_log(self, response: requests.Response, raw: bool = False) ->
logger.debug(RoundTrip(response, self._cfg.debug_headers, self._debug_truncate_bytes, raw).generate())


class _AddDebugErrorCustomizer(_ErrorCustomizer):
"""An error customizer that adds debug information about the configuration to unauthenticated and
unauthorized errors."""

def __init__(self, cfg: Config):
self._cfg = cfg

def customize_error(self, response: requests.Response, kwargs: dict):
if response.status_code in (401, 403):
message = kwargs.get('message', 'request failed')
kwargs['message'] = self._cfg.wrap_debug_info(message)


class StreamingResponse(BinaryIO):
_response: requests.Response
_buffer: bytes
Expand Down
4 changes: 2 additions & 2 deletions databricks/sdk/errors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .base import DatabricksError, ErrorDetail
from .mapper import _error_mapper
from .parser import get_api_error
from .customizer import _ErrorCustomizer
from .parser import _Parser
from .platform import *
from .private_link import PrivateLinkValidationError
from .sdk import *
50 changes: 50 additions & 0 deletions databricks/sdk/errors/customizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import abc
import logging

import requests


class _ErrorCustomizer(abc.ABC):
"""A customizer for errors from the Databricks REST API."""

@abc.abstractmethod
def customize_error(self, response: requests.Response, kwargs: dict):
"""Customize the error constructor parameters."""


class _RetryAfterCustomizer(_ErrorCustomizer):
"""An error customizer that sets the retry_after_secs parameter based on the Retry-After header."""

_DEFALT_RETRY_AFTER_SECONDS = 1
mgyucht marked this conversation as resolved.
Show resolved Hide resolved
"""The default number of seconds to wait before retrying a request if the Retry-After header is missing or is not
a valid integer."""

@classmethod
def _parse_retry_after(cls, response: requests.Response) -> int:
retry_after = response.headers.get("Retry-After")
if retry_after is None:
logging.debug(
f'No Retry-After header received in response with status code 429 or 503. Defaulting to {cls._DEFALT_RETRY_AFTER_SECONDS}'
)
# 429 requests should include a `Retry-After` header, but if it's missing,
# we default to 1 second.
return cls._DEFALT_RETRY_AFTER_SECONDS
# If the request is throttled, try parse the `Retry-After` header and sleep
# for the specified number of seconds. Note that this header can contain either
# an integer or a RFC1123 datetime string.
# See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
#
# For simplicity, we only try to parse it as an integer, as this is what Databricks
# platform returns. Otherwise, we fall back and don't sleep.
try:
return int(retry_after)
except ValueError:
logging.debug(
f'Invalid Retry-After header received: {retry_after}. Defaulting to {cls._DEFALT_RETRY_AFTER_SECONDS}'
)
# defaulting to 1 sleep second to make self._is_retryable() simpler
return cls._DEFALT_RETRY_AFTER_SECONDS

def customize_error(self, response: requests.Response, kwargs: dict):
if response.status_code in (429, 503):
kwargs['retry_after_secs'] = self._parse_retry_after(response)
106 changes: 106 additions & 0 deletions databricks/sdk/errors/deserializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import abc
import json
import logging
import re
from typing import Optional

import requests


class _ErrorDeserializer(abc.ABC):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed to _ErrorDeserializer and moved here from parser.py.

"""A parser for errors from the Databricks REST API."""

@abc.abstractmethod
def deserialize_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]:
"""Parses an error from the Databricks REST API. If the error cannot be parsed, returns None."""


class _EmptyDeserializer(_ErrorDeserializer):
"""A parser that handles empty responses."""

def deserialize_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]:
if len(response_body) == 0:
return {'message': response.reason}
return None


class _StandardErrorDeserializer(_ErrorDeserializer):
"""
Parses errors from the Databricks REST API using the standard error format.
"""

def deserialize_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]:
try:
payload_str = response_body.decode('utf-8')
resp = json.loads(payload_str)
except UnicodeDecodeError as e:
logging.debug('_StandardErrorParser: unable to decode response using utf-8', exc_info=e)
return None
except json.JSONDecodeError as e:
logging.debug('_StandardErrorParser: unable to deserialize response as json', exc_info=e)
return None
if not isinstance(resp, dict):
logging.debug('_StandardErrorParser: response is valid JSON but not a dictionary')
return None

error_args = {
'message': resp.get('message', 'request failed'),
'error_code': resp.get('error_code'),
'details': resp.get('details'),
}

# Handle API 1.2-style errors
if 'error' in resp:
error_args['message'] = resp['error']

# Handle SCIM Errors
detail = resp.get('detail')
status = resp.get('status')
scim_type = resp.get('scimType')
if detail:
# Handle SCIM error message details
# @see https://tools.ietf.org/html/rfc7644#section-3.7.3
if detail == "null":
detail = "SCIM API Internal Error"
error_args['message'] = f"{scim_type} {detail}".strip(" ")
error_args['error_code'] = f"SCIM_{status}"
return error_args


class _StringErrorDeserializer(_ErrorDeserializer):
"""
Parses errors from the Databricks REST API in the format "ERROR_CODE: MESSAGE".
"""

__STRING_ERROR_REGEX = re.compile(r'([A-Z_]+): (.*)')

def deserialize_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]:
payload_str = response_body.decode('utf-8')
match = self.__STRING_ERROR_REGEX.match(payload_str)
if not match:
logging.debug('_StringErrorParser: unable to parse response as string')
return None
error_code, message = match.groups()
return {'error_code': error_code, 'message': message, 'status': response.status_code, }


class _HtmlErrorDeserializer(_ErrorDeserializer):
"""
Parses errors from the Databricks REST API in HTML format.
"""

__HTML_ERROR_REGEXES = [re.compile(r'<pre>(.*)</pre>'), re.compile(r'<title>(.*)</title>'), ]

def deserialize_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]:
payload_str = response_body.decode('utf-8')
for regex in self.__HTML_ERROR_REGEXES:
match = regex.search(payload_str)
if match:
message = match.group(1) if match.group(1) else response.reason
return {
'status': response.status_code,
'message': message,
'error_code': response.reason.upper().replace(' ', '_')
}
logging.debug('_HtmlErrorParser: no <pre> tag found in error response')
return None
Loading
Loading