diff --git a/CHANGELOG.md b/CHANGELOG.md index 78574ee4d..9b5c0e242 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## Version 2.4.5 (2020-08-04) +### Added +* retry capabilities for common flaky API failures +* protection against improper types passed into `Project.upload_anntations` +* pass thru API error messages when possible + ## Version 2.4.3 (2020-08-04) ### Added diff --git a/labelbox/client.py b/labelbox/client.py index 6b0812f65..d81ad22b4 100644 --- a/labelbox/client.py +++ b/labelbox/client.py @@ -5,6 +5,7 @@ import os from typing import Tuple +from google.api_core import retry import requests import requests.exceptions @@ -60,6 +61,8 @@ def __init__(self, 'Authorization': 'Bearer %s' % api_key } + @retry.Retry(predicate=retry.if_exception_type( + labelbox.exceptions.InternalServerError)) def execute(self, query, params=None, timeout=10.0): """ Sends a request to the server for the execution of the given query. Checks the response for errors and wraps errors @@ -121,12 +124,15 @@ def convert_value(value): "Unknown error during Client.query(): " + str(e), e) try: - response = response.json() + r_json = response.json() except: + error_502 = '502 Bad Gateway' + if error_502 in response.text: + raise labelbox.exceptions.InternalServerError(error_502) raise labelbox.exceptions.LabelboxError( "Failed to parse response as JSON: %s" % response.text) - errors = response.get("errors", []) + errors = r_json.get("errors", []) def check_errors(keywords, *path): """ Helper that looks for any of the given `keywords` in any of @@ -166,16 +172,32 @@ def check_errors(keywords, *path): graphql_error["message"]) # Check if API limit was exceeded - response_msg = response.get("message", "") + response_msg = r_json.get("message", "") if response_msg.startswith("You have exceeded"): raise labelbox.exceptions.ApiLimitError(response_msg) + prisma_error = check_errors(["INTERNAL_SERVER_ERROR"], "extensions", + "code") + if prisma_error: + raise labelbox.exceptions.InternalServerError( + prisma_error["message"]) + if len(errors) > 0: logger.warning("Unparsed errors on query execution: %r", errors) raise labelbox.exceptions.LabelboxError("Unknown error: %s" % str(errors)) - return response["data"] + # if we do return a proper error code, and didn't catch this above + # reraise + # this mainly catches a 401 for API access disabled for free tier + # TODO: need to unify API errors to handle things more uniformly + # in the SDK + if response.status_code != requests.codes.ok: + message = f"{response.status_code} {response.reason}" + cause = r_json.get('message') + raise labelbox.exceptions.LabelboxError(message, cause) + + return r_json["data"] def upload_file(self, path: str) -> str: """Uploads given path to local file. diff --git a/labelbox/exceptions.py b/labelbox/exceptions.py index 3c6fae4fb..9dc8a75cf 100644 --- a/labelbox/exceptions.py +++ b/labelbox/exceptions.py @@ -48,6 +48,16 @@ class ValidationFailedError(LabelboxError): pass +class InternalServerError(LabelboxError): + """Nondescript prisma or 502 related errors. + + Meant to be retryable. + + TODO: these errors need better messages from platform + """ + pass + + class InvalidQueryError(LabelboxError): """ Indicates a malconstructed or unsupported query (either by GraphQL in general or by Labelbox specifically). This can be the result of either client diff --git a/labelbox/schema/bulk_import_request.py b/labelbox/schema/bulk_import_request.py index a5ebeeb3d..b1ff198ce 100644 --- a/labelbox/schema/bulk_import_request.py +++ b/labelbox/schema/bulk_import_request.py @@ -222,6 +222,9 @@ def create_from_objects(cls, client, project_id: str, name: str, """ _validate_ndjson(predictions) data_str = ndjson.dumps(predictions) + if not data_str: + raise ValueError('annotations cannot be empty') + data = data_str.encode('utf-8') file_name = _make_file_name(project_id, name) request_data = _make_request_data(project_id, name, len(data_str), diff --git a/labelbox/schema/project.py b/labelbox/schema/project.py index bd35d7306..d32ca6be1 100644 --- a/labelbox/schema/project.py +++ b/labelbox/schema/project.py @@ -410,13 +410,16 @@ def _is_url_valid(url: Union[str, Path]) -> bool: file=path, validate_file=True, ) - else: + elif isinstance(annotations, Iterable): return BulkImportRequest.create_from_objects( client=self.client, project_id=self.uid, name=name, predictions=annotations, # type: ignore ) + else: + raise ValueError( + f'Invalid annotations given of type: {type(annotations)}') class LabelingParameterOverride(DbObject): diff --git a/mypy.ini b/mypy.ini index 161703e8e..1f440de26 100644 --- a/mypy.ini +++ b/mypy.ini @@ -3,3 +3,6 @@ ignore_missing_imports = True [mypy-ndjson.*] ignore_missing_imports = True + +[mypy-google.*] +ignore_missing_imports = True diff --git a/setup.py b/setup.py index 4baf6fd31..0a5985644 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name="labelbox", - version="2.4.4", + version="2.4.5", author="Labelbox", author_email="engineering@labelbox.com", description="Labelbox Python API", @@ -13,7 +13,12 @@ long_description_content_type="text/markdown", url="https://labelbox.com", packages=setuptools.find_packages(), - install_requires=["backoff==1.10.0", "ndjson==0.3.1", "requests==2.22.0"], + install_requires=[ + "backoff==1.10.0", + "ndjson==0.3.1", + "requests>=2.22.0", + "google-api-core>=1.22.1", + ], classifiers=[ 'Development Status :: 3 - Alpha', 'License :: OSI Approved :: Apache Software License', diff --git a/tests/integration/test_labeling_frontend.py b/tests/integration/test_labeling_frontend.py index 94142d926..7c1c72751 100644 --- a/tests/integration/test_labeling_frontend.py +++ b/tests/integration/test_labeling_frontend.py @@ -3,13 +3,16 @@ def test_get_labeling_frontends(client): frontends = list(client.get_labeling_frontends()) - assert len(frontends) == 1, frontends + assert len(frontends) >= 1, ( + 'Projects should have at least one frontend by default.') # Test filtering - single = list( - client.get_labeling_frontends(where=LabelingFrontend.iframe_url_path == - frontends[0].iframe_url_path)) - assert len(single) == 1, single + target_frontend = frontends[0] + filtered_frontends = client.get_labeling_frontends( + where=LabelingFrontend.iframe_url_path == + target_frontend.iframe_url_path) + for frontend in filtered_frontends: + assert target_frontend == frontend def test_labeling_frontend_connecting_to_project(project):