diff --git a/databricks/sdk/core.py b/databricks/sdk/core.py index e028e4b15..77e8c9aac 100644 --- a/databricks/sdk/core.py +++ b/databricks/sdk/core.py @@ -10,7 +10,7 @@ from .config import * # To preserve backwards compatibility (as these definitions were previously in this module) from .credentials_provider import * -from .errors import DatabricksError, get_api_error +from .errors import DatabricksError, _ErrorCustomizer, _Parser from .logger import RoundTrip from .oauth import retrieve_token from .retries import retried @@ -71,6 +71,8 @@ def __init__(self, cfg: Config = None): # Default to 60 seconds self._http_timeout_seconds = cfg.http_timeout_seconds if cfg.http_timeout_seconds else 60 + self._error_parser = _Parser(extra_error_customizers=[_AddDebugErrorCustomizer(cfg)]) + @property def account_id(self) -> str: return self._cfg.account_id @@ -219,27 +221,6 @@ def _is_retryable(err: BaseException) -> Optional[str]: return f'matched {substring}' return None - @classmethod - def _parse_retry_after(cls, response: requests.Response) -> Optional[int]: - retry_after = response.headers.get("Retry-After") - if retry_after is None: - # 429 requests should include a `Retry-After` header, but if it's missing, - # we default to 1 second. - return cls._RETRY_AFTER_DEFAULT - # If the request is throttled, try parse the `Retry-After` header and sleep - # for the specified number of seconds. Note that this header can contain either - # an integer or a RFC1123 datetime string. - # See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After - # - # For simplicity, we only try to parse it as an integer, as this is what Databricks - # platform returns. Otherwise, we fall back and don't sleep. - try: - return int(retry_after) - except ValueError: - logger.debug(f'Invalid Retry-After header received: {retry_after}. Defaulting to 1') - # defaulting to 1 sleep second to make self._is_retryable() simpler - return cls._RETRY_AFTER_DEFAULT - def _perform(self, method: str, url: str, @@ -261,15 +242,8 @@ def _perform(self, stream=raw, timeout=self._http_timeout_seconds) self._record_request_log(response, raw=raw or data is not None or files is not None) - error = get_api_error(response) + error = self._error_parser.get_api_error(response) if error is not None: - status_code = response.status_code - is_http_unauthorized_or_forbidden = status_code in (401, 403) - is_too_many_requests_or_unavailable = status_code in (429, 503) - if is_http_unauthorized_or_forbidden: - error.message = self._cfg.wrap_debug_info(error.message) - if is_too_many_requests_or_unavailable: - error.retry_after_secs = self._parse_retry_after(response) raise error from None return response @@ -279,6 +253,19 @@ def _record_request_log(self, response: requests.Response, raw: bool = False) -> logger.debug(RoundTrip(response, self._cfg.debug_headers, self._debug_truncate_bytes, raw).generate()) +class _AddDebugErrorCustomizer(_ErrorCustomizer): + """An error customizer that adds debug information about the configuration to unauthenticated and + unauthorized errors.""" + + def __init__(self, cfg: Config): + self._cfg = cfg + + def customize_error(self, response: requests.Response, kwargs: dict): + if response.status_code in (401, 403): + message = kwargs.get('message', 'request failed') + kwargs['message'] = self._cfg.wrap_debug_info(message) + + class StreamingResponse(BinaryIO): _response: requests.Response _buffer: bytes diff --git a/databricks/sdk/errors/__init__.py b/databricks/sdk/errors/__init__.py index 578406803..8ad5ac708 100644 --- a/databricks/sdk/errors/__init__.py +++ b/databricks/sdk/errors/__init__.py @@ -1,6 +1,6 @@ from .base import DatabricksError, ErrorDetail -from .mapper import _error_mapper -from .parser import get_api_error +from .customizer import _ErrorCustomizer +from .parser import _Parser from .platform import * from .private_link import PrivateLinkValidationError from .sdk import * diff --git a/databricks/sdk/errors/customizer.py b/databricks/sdk/errors/customizer.py new file mode 100644 index 000000000..5c895becc --- /dev/null +++ b/databricks/sdk/errors/customizer.py @@ -0,0 +1,50 @@ +import abc +import logging + +import requests + + +class _ErrorCustomizer(abc.ABC): + """A customizer for errors from the Databricks REST API.""" + + @abc.abstractmethod + def customize_error(self, response: requests.Response, kwargs: dict): + """Customize the error constructor parameters.""" + + +class _RetryAfterCustomizer(_ErrorCustomizer): + """An error customizer that sets the retry_after_secs parameter based on the Retry-After header.""" + + _DEFAULT_RETRY_AFTER_SECONDS = 1 + """The default number of seconds to wait before retrying a request if the Retry-After header is missing or is not + a valid integer.""" + + @classmethod + def _parse_retry_after(cls, response: requests.Response) -> int: + retry_after = response.headers.get("Retry-After") + if retry_after is None: + logging.debug( + f'No Retry-After header received in response with status code 429 or 503. Defaulting to {cls._DEFAULT_RETRY_AFTER_SECONDS}' + ) + # 429 requests should include a `Retry-After` header, but if it's missing, + # we default to 1 second. + return cls._DEFAULT_RETRY_AFTER_SECONDS + # If the request is throttled, try parse the `Retry-After` header and sleep + # for the specified number of seconds. Note that this header can contain either + # an integer or a RFC1123 datetime string. + # See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After + # + # For simplicity, we only try to parse it as an integer, as this is what Databricks + # platform returns. Otherwise, we fall back and don't sleep. + try: + return int(retry_after) + except ValueError: + logging.debug( + f'Invalid Retry-After header received: {retry_after}. Defaulting to {cls._DEFAULT_RETRY_AFTER_SECONDS}' + ) + # defaulting to 1 sleep second to make self._is_retryable() simpler + return cls._DEFAULT_RETRY_AFTER_SECONDS + + def customize_error(self, response: requests.Response, kwargs: dict): + if response.status_code in (429, 503): + kwargs['retry_after_secs'] = self._parse_retry_after(response) diff --git a/databricks/sdk/errors/deserializer.py b/databricks/sdk/errors/deserializer.py new file mode 100644 index 000000000..4da01ee68 --- /dev/null +++ b/databricks/sdk/errors/deserializer.py @@ -0,0 +1,106 @@ +import abc +import json +import logging +import re +from typing import Optional + +import requests + + +class _ErrorDeserializer(abc.ABC): + """A parser for errors from the Databricks REST API.""" + + @abc.abstractmethod + def deserialize_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]: + """Parses an error from the Databricks REST API. If the error cannot be parsed, returns None.""" + + +class _EmptyDeserializer(_ErrorDeserializer): + """A parser that handles empty responses.""" + + def deserialize_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]: + if len(response_body) == 0: + return {'message': response.reason} + return None + + +class _StandardErrorDeserializer(_ErrorDeserializer): + """ + Parses errors from the Databricks REST API using the standard error format. + """ + + def deserialize_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]: + try: + payload_str = response_body.decode('utf-8') + resp = json.loads(payload_str) + except UnicodeDecodeError as e: + logging.debug('_StandardErrorParser: unable to decode response using utf-8', exc_info=e) + return None + except json.JSONDecodeError as e: + logging.debug('_StandardErrorParser: unable to deserialize response as json', exc_info=e) + return None + if not isinstance(resp, dict): + logging.debug('_StandardErrorParser: response is valid JSON but not a dictionary') + return None + + error_args = { + 'message': resp.get('message', 'request failed'), + 'error_code': resp.get('error_code'), + 'details': resp.get('details'), + } + + # Handle API 1.2-style errors + if 'error' in resp: + error_args['message'] = resp['error'] + + # Handle SCIM Errors + detail = resp.get('detail') + status = resp.get('status') + scim_type = resp.get('scimType') + if detail: + # Handle SCIM error message details + # @see https://tools.ietf.org/html/rfc7644#section-3.7.3 + if detail == "null": + detail = "SCIM API Internal Error" + error_args['message'] = f"{scim_type} {detail}".strip(" ") + error_args['error_code'] = f"SCIM_{status}" + return error_args + + +class _StringErrorDeserializer(_ErrorDeserializer): + """ + Parses errors from the Databricks REST API in the format "ERROR_CODE: MESSAGE". + """ + + __STRING_ERROR_REGEX = re.compile(r'([A-Z_]+): (.*)') + + def deserialize_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]: + payload_str = response_body.decode('utf-8') + match = self.__STRING_ERROR_REGEX.match(payload_str) + if not match: + logging.debug('_StringErrorParser: unable to parse response as string') + return None + error_code, message = match.groups() + return {'error_code': error_code, 'message': message, 'status': response.status_code, } + + +class _HtmlErrorDeserializer(_ErrorDeserializer): + """ + Parses errors from the Databricks REST API in HTML format. + """ + + __HTML_ERROR_REGEXES = [re.compile(r'
(.*)'), re.compile(r'
tag found in error response') + return None diff --git a/databricks/sdk/errors/parser.py b/databricks/sdk/errors/parser.py index 3d15f1673..3408964fe 100644 --- a/databricks/sdk/errors/parser.py +++ b/databricks/sdk/errors/parser.py @@ -1,115 +1,32 @@ -import abc -import json import logging -import re -from typing import Optional +from typing import List, Optional import requests from ..logger import RoundTrip from .base import DatabricksError +from .customizer import _ErrorCustomizer, _RetryAfterCustomizer +from .deserializer import (_EmptyDeserializer, _ErrorDeserializer, + _HtmlErrorDeserializer, _StandardErrorDeserializer, + _StringErrorDeserializer) from .mapper import _error_mapper from .private_link import (_get_private_link_validation_error, _is_private_link_redirect) +# A list of _ErrorDeserializers that are tried in order to parse an API error from a response body. Most errors should +# be parsable by the _StandardErrorDeserializer, but additional parsers can be added here for specific error formats. +# The order of the parsers is not important, as the set of errors that can be parsed by each parser should be disjoint. +_error_deserializers = [ + _EmptyDeserializer(), + _StandardErrorDeserializer(), + _StringErrorDeserializer(), + _HtmlErrorDeserializer(), +] -class _ErrorParser(abc.ABC): - """A parser for errors from the Databricks REST API.""" - - @abc.abstractmethod - def parse_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]: - """Parses an error from the Databricks REST API. If the error cannot be parsed, returns None.""" - - -class _EmptyParser(_ErrorParser): - """A parser that handles empty responses.""" - - def parse_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]: - if len(response_body) == 0: - return {'message': response.reason} - return None - - -class _StandardErrorParser(_ErrorParser): - """ - Parses errors from the Databricks REST API using the standard error format. - """ - - def parse_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]: - try: - payload_str = response_body.decode('utf-8') - resp: dict = json.loads(payload_str) - except json.JSONDecodeError as e: - logging.debug('_StandardErrorParser: unable to deserialize response as json', exc_info=e) - return None - - error_args = { - 'message': resp.get('message', 'request failed'), - 'error_code': resp.get('error_code'), - 'details': resp.get('details'), - } - - # Handle API 1.2-style errors - if 'error' in resp: - error_args['message'] = resp['error'] - - # Handle SCIM Errors - detail = resp.get('detail') - status = resp.get('status') - scim_type = resp.get('scimType') - if detail: - # Handle SCIM error message details - # @see https://tools.ietf.org/html/rfc7644#section-3.7.3 - if detail == "null": - detail = "SCIM API Internal Error" - error_args['message'] = f"{scim_type} {detail}".strip(" ") - error_args['error_code'] = f"SCIM_{status}" - return error_args - - -class _StringErrorParser(_ErrorParser): - """ - Parses errors from the Databricks REST API in the format "ERROR_CODE: MESSAGE". - """ - - __STRING_ERROR_REGEX = re.compile(r'([A-Z_]+): (.*)') - - def parse_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]: - payload_str = response_body.decode('utf-8') - match = self.__STRING_ERROR_REGEX.match(payload_str) - if not match: - logging.debug('_StringErrorParser: unable to parse response as string') - return None - error_code, message = match.groups() - return {'error_code': error_code, 'message': message, 'status': response.status_code, } - - -class _HtmlErrorParser(_ErrorParser): - """ - Parses errors from the Databricks REST API in HTML format. - """ - - __HTML_ERROR_REGEXES = [re.compile(r'(.*)'), re.compile(r'(.*) '), ] - - def parse_error(self, response: requests.Response, response_body: bytes) -> Optional[dict]: - payload_str = response_body.decode('utf-8') - for regex in self.__HTML_ERROR_REGEXES: - match = regex.search(payload_str) - if match: - message = match.group(1) if match.group(1) else response.reason - return { - 'status': response.status_code, - 'message': message, - 'error_code': response.reason.upper().replace(' ', '_') - } - logging.debug('_HtmlErrorParser: notag found in error response') - return None - - -# A list of ErrorParsers that are tried in order to parse an API error from a response body. Most errors should be -# parsable by the _StandardErrorParser, but additional parsers can be added here for specific error formats. The order -# of the parsers is not important, as the set of errors that can be parsed by each parser should be disjoint. -_error_parsers = [_EmptyParser(), _StandardErrorParser(), _StringErrorParser(), _HtmlErrorParser(), ] +# A list of _ErrorCustomizers that are applied to the error arguments after they are parsed. Customizers can modify the +# error arguments in any way, including adding or removing fields. Customizers are applied in order, so later +# customizers can override the changes made by earlier customizers. +_error_customizers = [_RetryAfterCustomizer(), ] def _unknown_error(response: requests.Response) -> str: @@ -124,24 +41,43 @@ def _unknown_error(response: requests.Response) -> str: f'https://github.com/databricks/databricks-sdk-go/issues. Request log:```{request_log}```') -def get_api_error(response: requests.Response) -> Optional[DatabricksError]: +class _Parser: """ - Handles responses from the REST API and returns a DatabricksError if the response indicates an error. - :param response: The response from the REST API. - :return: A DatabricksError if the response indicates an error, otherwise None. + A parser for errors from the Databricks REST API. It attempts to deserialize an error using a sequence of + deserializers, and then customizes the deserialized error using a sequence of customizers. If the error cannot be + deserialized, it returns a generic error with debugging information and instructions to report the issue to the SDK + issue tracker. """ - if not response.ok: - content = response.content - for parser in _error_parsers: - try: - error_args = parser.parse_error(response, content) - if error_args: - return _error_mapper(response, error_args) - except Exception as e: - logging.debug(f'Error parsing response with {parser}, continuing', exc_info=e) - return _error_mapper(response, {'message': 'unable to parse response. ' + _unknown_error(response)}) - # Private link failures happen via a redirect to the login page. From a requests-perspective, the request - # is successful, but the response is not what we expect. We need to handle this case separately. - if _is_private_link_redirect(response): - return _get_private_link_validation_error(response.url) + def __init__(self, + extra_error_parsers: List[_ErrorDeserializer] = [], + extra_error_customizers: List[_ErrorCustomizer] = []): + self._error_parsers = _error_deserializers + (extra_error_parsers + if extra_error_parsers is not None else []) + self._error_customizers = _error_customizers + (extra_error_customizers + if extra_error_customizers is not None else []) + + def get_api_error(self, response: requests.Response) -> Optional[DatabricksError]: + """ + Handles responses from the REST API and returns a DatabricksError if the response indicates an error. + :param response: The response from the REST API. + :return: A DatabricksError if the response indicates an error, otherwise None. + """ + if not response.ok: + content = response.content + for parser in self._error_parsers: + try: + error_args = parser.deserialize_error(response, content) + if error_args: + for customizer in self._error_customizers: + customizer.customize_error(response, error_args) + return _error_mapper(response, error_args) + except Exception as e: + logging.debug(f'Error parsing response with {parser}, continuing', exc_info=e) + return _error_mapper(response, + {'message': 'unable to parse response. ' + _unknown_error(response)}) + + # Private link failures happen via a redirect to the login page. From a requests-perspective, the request + # is successful, but the response is not what we expect. We need to handle this case separately. + if _is_private_link_redirect(response): + return _get_private_link_validation_error(response.url) diff --git a/databricks/sdk/logger/round_trip_logger.py b/databricks/sdk/logger/round_trip_logger.py index f1d177aaa..1c0a47f08 100644 --- a/databricks/sdk/logger/round_trip_logger.py +++ b/databricks/sdk/logger/round_trip_logger.py @@ -48,7 +48,8 @@ def generate(self) -> str: # Raw streams with `Transfer-Encoding: chunked` do not have `Content-Type` header sb.append("< [raw stream]") elif self._response.content: - sb.append(self._redacted_dump("< ", self._response.content.decode('utf-8'))) + decoded = self._response.content.decode('utf-8', errors='replace') + sb.append(self._redacted_dump("< ", decoded)) return '\n'.join(sb) @staticmethod diff --git a/tests/test_core.py b/tests/test_core.py index cc7926a72..d54563d4e 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -13,7 +13,7 @@ import pytest import requests -from databricks.sdk import WorkspaceClient +from databricks.sdk import WorkspaceClient, errors from databricks.sdk.core import (ApiClient, Config, DatabricksError, StreamingResponse) from databricks.sdk.credentials_provider import (CliTokenSource, @@ -359,8 +359,8 @@ def test_deletes(config, requests_mock): assert res is None -def test_error(config, requests_mock): - errorJson = { +@pytest.mark.parametrize('status_code,headers,body,expected_error', [ + (400, {}, { "message": "errorMessage", "details": [{ @@ -378,26 +378,74 @@ def test_error(config, requests_mock): "etag": "wrong etag" } }], - } - + }, + errors.BadRequest('errorMessage', + details=[{ + 'type': DatabricksError._error_info_type, + 'reason': 'error reason', + 'domain': 'error domain', + 'metadata': { + 'etag': 'error etag' + }, + }])), + (401, {}, { + 'error_code': 'UNAUTHORIZED', + 'message': 'errorMessage', + }, + errors.Unauthenticated('errorMessage. Config: host=http://localhost, auth_type=noop', + error_code='UNAUTHORIZED')), + (403, {}, { + 'error_code': 'FORBIDDEN', + 'message': 'errorMessage', + }, + errors.PermissionDenied('errorMessage. Config: host=http://localhost, auth_type=noop', + error_code='FORBIDDEN')), + (429, {}, { + 'error_code': 'TOO_MANY_REQUESTS', + 'message': 'errorMessage', + }, errors.TooManyRequests('errorMessage', error_code='TOO_MANY_REQUESTS', retry_after_secs=1)), + (429, { + 'Retry-After': '100' + }, { + 'error_code': 'TOO_MANY_REQUESTS', + 'message': 'errorMessage', + }, errors.TooManyRequests('errorMessage', error_code='TOO_MANY_REQUESTS', retry_after_secs=100)), + (503, {}, { + 'error_code': 'TEMPORARILY_UNAVAILABLE', + 'message': 'errorMessage', + }, errors.TemporarilyUnavailable('errorMessage', error_code='TEMPORARILY_UNAVAILABLE', + retry_after_secs=1)), + (503, { + 'Retry-After': '100' + }, { + 'error_code': 'TEMPORARILY_UNAVAILABLE', + 'message': 'errorMessage', + }, + errors.TemporarilyUnavailable('errorMessage', error_code='TEMPORARILY_UNAVAILABLE', + retry_after_secs=100)), + (404, {}, { + 'scimType': 'scim type', + 'detail': 'detail', + 'status': 'status', + }, errors.NotFound('scim type detail', error_code='SCIM_status')), +]) +def test_error(config, requests_mock, status_code, headers, body, expected_error): client = ApiClient(config) - requests_mock.get("/test", json=errorJson, status_code=400, ) + requests_mock.get("/test", json=body, status_code=status_code, headers=headers) with pytest.raises(DatabricksError) as raised: - client.do("GET", "/test", headers={"test": "test"}) - - error_infos = raised.value.get_error_info() - assert len(error_infos) == 1 - error_info = error_infos[0] - assert error_info.reason == "error reason" - assert error_info.domain == "error domain" - assert error_info.metadata["etag"] == "error etag" - assert error_info.type == DatabricksError._error_info_type - - -def test_error_with_scimType(): - args = {"detail": "detail", "scimType": "scim type"} - error = DatabricksError(**args) - assert str(error) == f"scim type detail" + client._perform("GET", "http://localhost/test", headers={"test": "test"}) + actual = raised.value + assert isinstance(actual, type(expected_error)) + assert str(actual) == str(expected_error) + assert actual.error_code == expected_error.error_code + assert actual.retry_after_secs == expected_error.retry_after_secs + expected_error_infos, actual_error_infos = expected_error.get_error_info(), actual.get_error_info() + assert len(expected_error_infos) == len(actual_error_infos) + for expected, actual in zip(expected_error_infos, actual_error_infos): + assert expected.type == actual.type + assert expected.reason == actual.reason + assert expected.domain == actual.domain + assert expected.metadata == actual.metadata @contextlib.contextmanager diff --git a/tests/test_errors.py b/tests/test_errors.py index 2e19ec897..881f016f3 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -12,13 +12,20 @@ def fake_response(method: str, status_code: int, response_body: str, path: Optional[str] = None) -> requests.Response: + return fake_raw_response(method, status_code, response_body.encode('utf-8'), path) + + +def fake_raw_response(method: str, + status_code: int, + response_body: bytes, + path: Optional[str] = None) -> requests.Response: resp = requests.Response() resp.status_code = status_code resp.reason = http.client.responses.get(status_code, '') if path is None: path = '/api/2.0/service' resp.request = requests.Request(method, f"https://databricks.com{path}").prepare() - resp._content = response_body.encode('utf-8') + resp._content = response_body return resp @@ -110,17 +117,22 @@ def make_private_link_response() -> requests.Response: 'https://github.com/databricks/databricks-sdk-go/issues. Request log:```GET /api/2.0/service\n' '< 400 Bad Request\n' '< this is not a real response```')), - [ - fake_response( - 'GET', 404, - json.dumps({ - 'detail': 'Group with id 1234 is not found', - 'status': '404', - 'schemas': ['urn:ietf:params:scim:api:messages:2.0:Error'] - })), errors.NotFound, 'None Group with id 1234 is not found' - ]]) + (fake_response( + 'GET', 404, + json.dumps({ + 'detail': 'Group with id 1234 is not found', + 'status': '404', + 'schemas': ['urn:ietf:params:scim:api:messages:2.0:Error'] + })), errors.NotFound, 'None Group with id 1234 is not found'), + (fake_response('GET', 404, json.dumps("This is JSON but not a dictionary")), errors.NotFound, + 'unable to parse response. This is likely a bug in the Databricks SDK for Python or the underlying API. Please report this issue with the following debugging information to the SDK issue tracker at https://github.com/databricks/databricks-sdk-go/issues. Request log:```GET /api/2.0/service\n< 404 Not Found\n< "This is JSON but not a dictionary"```' + ), + (fake_raw_response('GET', 404, b'\x80'), errors.NotFound, + 'unable to parse response. This is likely a bug in the Databricks SDK for Python or the underlying API. Please report this issue with the following debugging information to the SDK issue tracker at https://github.com/databricks/databricks-sdk-go/issues. Request log:```GET /api/2.0/service\n< 404 Not Found\n< �```' + )]) def test_get_api_error(response, expected_error, expected_message): + parser = errors._Parser() with pytest.raises(errors.DatabricksError) as e: - raise errors.get_api_error(response) + raise parser.get_api_error(response) assert isinstance(e.value, expected_error) assert str(e.value) == expected_message