-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow SDK methods to accept request_timeout
through kwargs
#332
Draft
CoreyEWood
wants to merge
2
commits into
main
Choose a base branch
from
request-timeout-from-kwargs
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -223,7 +223,11 @@ def _fixup_image_query(iq: ImageQuery) -> ImageQuery: | |
iq.result.label = convert_internal_label_to_display(iq, iq.result.label) | ||
return iq | ||
|
||
def whoami(self) -> str: | ||
def _get_request_timeout(self, **kwargs): | ||
"""Extract request_timeout from kwargs or use default.""" | ||
return kwargs.get("request_timeout", DEFAULT_REQUEST_TIMEOUT) | ||
|
||
def whoami(self, **kwargs) -> str: | ||
""" | ||
Return the username (email address) associated with the current API token. | ||
|
||
|
@@ -240,7 +244,7 @@ def whoami(self) -> str: | |
:raises ApiTokenError: If the API token is invalid | ||
:raises GroundlightClientError: If there are connectivity issues with the Groundlight service | ||
""" | ||
obj = self.user_api.who_am_i(_request_timeout=DEFAULT_REQUEST_TIMEOUT) | ||
obj = self.user_api.who_am_i(_request_timeout=self._get_request_timeout(**kwargs)) | ||
return obj["email"] | ||
|
||
def _user_is_privileged(self) -> bool: | ||
|
@@ -251,7 +255,7 @@ def _user_is_privileged(self) -> bool: | |
obj = self.user_api.who_am_i() | ||
return obj["is_superuser"] | ||
|
||
def get_detector(self, id: Union[str, Detector]) -> Detector: # pylint: disable=redefined-builtin | ||
def get_detector(self, id: Union[str, Detector], **kwargs) -> Detector: # pylint: disable=redefined-builtin | ||
""" | ||
Get a Detector by id. | ||
|
||
|
@@ -270,12 +274,14 @@ def get_detector(self, id: Union[str, Detector]) -> Detector: # pylint: disable | |
# Short-circuit | ||
return id | ||
try: | ||
obj = self.detectors_api.get_detector(id=id, _request_timeout=DEFAULT_REQUEST_TIMEOUT) | ||
obj = self.detectors_api.get_detector(id=id, _request_timeout=self._get_request_timeout(**kwargs)) | ||
except NotFoundException as e: | ||
raise NotFoundError(f"Detector with id '{id}' not found") from e | ||
return Detector.parse_obj(obj.to_dict()) | ||
|
||
def get_detector_by_name(self, name: str) -> Detector: | ||
# TODO should methods like this (which make direct requests instead of going through the API) allow | ||
# kwargs to be passed through? | ||
Comment on lines
282
to
+284
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here's one type of tricky case for allowing |
||
""" | ||
Get a Detector by name. | ||
|
||
|
@@ -291,7 +297,7 @@ def get_detector_by_name(self, name: str) -> Detector: | |
""" | ||
return self.api_client._get_detector_by_name(name) # pylint: disable=protected-access | ||
|
||
def list_detectors(self, page: int = 1, page_size: int = 10) -> PaginatedDetectorList: | ||
def list_detectors(self, page: int = 1, page_size: int = 10, **kwargs) -> PaginatedDetectorList: | ||
""" | ||
Retrieve a paginated list of detectors associated with your account. | ||
|
||
|
@@ -312,7 +318,7 @@ def list_detectors(self, page: int = 1, page_size: int = 10) -> PaginatedDetecto | |
:return: PaginatedDetectorList containing the requested page of detectors and pagination metadata | ||
""" | ||
obj = self.detectors_api.list_detectors( | ||
page=page, page_size=page_size, _request_timeout=DEFAULT_REQUEST_TIMEOUT | ||
page=page, page_size=page_size, _request_timeout=self._get_request_timeout(**kwargs) | ||
) | ||
return PaginatedDetectorList.parse_obj(obj.to_dict()) | ||
|
||
|
@@ -358,6 +364,7 @@ def create_detector( # noqa: PLR0913 | |
patience_time: Optional[float] = None, | ||
pipeline_config: Optional[str] = None, | ||
metadata: Union[dict, str, None] = None, | ||
**kwargs, | ||
) -> Detector: | ||
""" | ||
Create a new Detector with a given name and query. | ||
|
@@ -423,7 +430,9 @@ def create_detector( # noqa: PLR0913 | |
pipeline_config=pipeline_config, | ||
metadata=metadata, | ||
) | ||
obj = self.detectors_api.create_detector(detector_creation_input, _request_timeout=DEFAULT_REQUEST_TIMEOUT) | ||
obj = self.detectors_api.create_detector( | ||
detector_creation_input, _request_timeout=self._get_request_timeout(**kwargs) | ||
) | ||
return Detector.parse_obj(obj.to_dict()) | ||
|
||
def get_or_create_detector( # noqa: PLR0913 | ||
|
@@ -435,6 +444,7 @@ def get_or_create_detector( # noqa: PLR0913 | |
confidence_threshold: Optional[float] = None, | ||
pipeline_config: Optional[str] = None, | ||
metadata: Union[dict, str, None] = None, | ||
**kwargs, | ||
) -> Detector: | ||
""" | ||
Tries to look up the Detector by name. If a Detector with that name, query, and | ||
|
@@ -491,6 +501,7 @@ def get_or_create_detector( # noqa: PLR0913 | |
confidence_threshold=confidence_threshold, | ||
pipeline_config=pipeline_config, | ||
metadata=metadata, | ||
**kwargs, | ||
) | ||
|
||
# TODO: We may soon allow users to update the retrieved detector's fields. | ||
|
@@ -512,7 +523,7 @@ def get_or_create_detector( # noqa: PLR0913 | |
) | ||
return existing_detector | ||
|
||
def get_image_query(self, id: str) -> ImageQuery: # pylint: disable=redefined-builtin | ||
def get_image_query(self, id: str, **kwargs) -> ImageQuery: # pylint: disable=redefined-builtin | ||
""" | ||
Get an ImageQuery by its ID. This is useful for retrieving the status and results of a | ||
previously submitted query. | ||
|
@@ -534,15 +545,15 @@ def get_image_query(self, id: str) -> ImageQuery: # pylint: disable=redefined-b | |
|
||
:return: ImageQuery object containing the query details and results | ||
""" | ||
obj = self.image_queries_api.get_image_query(id=id, _request_timeout=DEFAULT_REQUEST_TIMEOUT) | ||
obj = self.image_queries_api.get_image_query(id=id, _request_timeout=self._get_request_timeout(**kwargs)) | ||
if obj.result_type == "counting" and getattr(obj.result, "label", None): | ||
obj.result.pop("label") | ||
obj.result["count"] = None | ||
iq = ImageQuery.parse_obj(obj.to_dict()) | ||
return self._fixup_image_query(iq) | ||
|
||
def list_image_queries( | ||
self, page: int = 1, page_size: int = 10, detector_id: Union[str, None] = None | ||
self, page: int = 1, page_size: int = 10, detector_id: Union[str, None] = None, **kwargs | ||
) -> PaginatedImageQueryList: | ||
""" | ||
List all image queries associated with your account, with pagination support. | ||
|
@@ -565,7 +576,11 @@ def list_image_queries( | |
:return: PaginatedImageQueryList containing the requested page of image queries and pagination metadata | ||
like total count and links to next/previous pages. | ||
""" | ||
params: dict[str, Any] = {"page": page, "page_size": page_size, "_request_timeout": DEFAULT_REQUEST_TIMEOUT} | ||
params: dict[str, Any] = { | ||
"page": page, | ||
"page_size": page_size, | ||
"_request_timeout": self._get_request_timeout(**kwargs), | ||
} | ||
if detector_id: | ||
params["detector_id"] = detector_id | ||
obj = self.image_queries_api.list_image_queries(**params) | ||
|
@@ -586,6 +601,7 @@ def submit_image_query( # noqa: PLR0913 # pylint: disable=too-many-arguments, t | |
inspection_id: Optional[str] = None, | ||
metadata: Union[dict, str, None] = None, | ||
image_query_id: Optional[str] = None, | ||
**kwargs, | ||
) -> ImageQuery: | ||
""" | ||
Evaluates an image with Groundlight. This is the core method for getting predictions about images. | ||
|
@@ -680,7 +696,11 @@ def submit_image_query( # noqa: PLR0913 # pylint: disable=too-many-arguments, t | |
|
||
image_bytesio: ByteStreamWrapper = parse_supported_image_types(image) | ||
|
||
params = {"detector_id": detector_id, "body": image_bytesio, "_request_timeout": DEFAULT_REQUEST_TIMEOUT} | ||
params = { | ||
"detector_id": detector_id, | ||
"body": image_bytesio, | ||
"_request_timeout": self._get_request_timeout(**kwargs), | ||
} | ||
|
||
if patience_time is not None: | ||
params["patience_time"] = patience_time | ||
|
@@ -732,6 +752,7 @@ def ask_confident( # noqa: PLR0913 # pylint: disable=too-many-arguments | |
wait: Optional[float] = None, | ||
metadata: Union[dict, str, None] = None, | ||
inspection_id: Optional[str] = None, | ||
**kwargs, | ||
) -> ImageQuery: | ||
""" | ||
Evaluates an image with Groundlight, waiting until an answer above the confidence threshold | ||
|
@@ -788,6 +809,7 @@ def ask_confident( # noqa: PLR0913 # pylint: disable=too-many-arguments | |
human_review=None, | ||
metadata=metadata, | ||
inspection_id=inspection_id, | ||
**kwargs, | ||
) | ||
|
||
def ask_ml( # noqa: PLR0913 # pylint: disable=too-many-arguments, too-many-locals | ||
|
@@ -797,6 +819,7 @@ def ask_ml( # noqa: PLR0913 # pylint: disable=too-many-arguments, too-many-loca | |
wait: Optional[float] = None, | ||
metadata: Union[dict, str, None] = None, | ||
inspection_id: Optional[str] = None, | ||
**kwargs, | ||
) -> ImageQuery: | ||
""" | ||
Evaluates an image with Groundlight, getting the first ML prediction without waiting | ||
|
@@ -856,6 +879,7 @@ def ask_ml( # noqa: PLR0913 # pylint: disable=too-many-arguments, too-many-loca | |
wait=0, | ||
metadata=metadata, | ||
inspection_id=inspection_id, | ||
**kwargs, | ||
) | ||
if iq_is_answered(iq): | ||
return iq | ||
|
@@ -871,6 +895,7 @@ def ask_async( # noqa: PLR0913 # pylint: disable=too-many-arguments | |
human_review: Optional[str] = None, | ||
metadata: Union[dict, str, None] = None, | ||
inspection_id: Optional[str] = None, | ||
**kwargs, | ||
) -> ImageQuery: | ||
""" | ||
Submit an image query asynchronously. This is equivalent to calling `submit_image_query` | ||
|
@@ -952,6 +977,7 @@ def ask_async( # noqa: PLR0913 # pylint: disable=too-many-arguments | |
want_async=True, | ||
metadata=metadata, | ||
inspection_id=inspection_id, | ||
**kwargs, | ||
) | ||
|
||
def wait_for_confident_result( | ||
|
@@ -960,6 +986,9 @@ def wait_for_confident_result( | |
confidence_threshold: Optional[float] = None, | ||
timeout_sec: float = 30.0, | ||
) -> ImageQuery: | ||
# TODO should this method allow request_timeout to be passed in kwargs? | ||
# It's a weird case because it might make multiple requests - would the specified request_timeout be applied | ||
# to each request? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is another tricky case for accepting |
||
""" | ||
Waits for an image query result's confidence level to reach the specified confidence_threshold. | ||
Uses polling with exponential back-off to check for results. | ||
|
@@ -1092,6 +1121,7 @@ def add_label( | |
image_query: Union[ImageQuery, str], | ||
label: Union[Label, int, str], | ||
rois: Union[List[ROI], str, None] = None, | ||
**kwargs, | ||
): | ||
""" | ||
Provide a new label (annotation) for an image query. This is used to provide ground-truth labels | ||
|
@@ -1151,7 +1181,7 @@ def add_label( | |
else None | ||
) | ||
request_params = LabelValueRequest(label=label, image_query_id=image_query_id, rois=roi_requests) | ||
self.labels_api.create_label(request_params) | ||
self.labels_api.create_label(request_params, _request_timeout=self._get_request_timeout(**kwargs)) | ||
|
||
def start_inspection(self) -> str: | ||
""" | ||
|
@@ -1189,7 +1219,9 @@ def stop_inspection(self, inspection_id: str) -> str: | |
""" | ||
return self.api_client.stop_inspection(inspection_id) | ||
|
||
def update_detector_confidence_threshold(self, detector: Union[str, Detector], confidence_threshold: float) -> None: | ||
def update_detector_confidence_threshold( | ||
self, detector: Union[str, Detector], confidence_threshold: float, **kwargs | ||
) -> None: | ||
""" | ||
Updates the confidence threshold for the given detector | ||
|
||
|
@@ -1203,5 +1235,7 @@ def update_detector_confidence_threshold(self, detector: Union[str, Detector], c | |
if confidence_threshold < 0 or confidence_threshold > 1: | ||
raise ValueError("confidence must be between 0 and 1") | ||
self.detectors_api.update_detector( | ||
detector, patched_detector_request=PatchedDetectorRequest(confidence_threshold=confidence_threshold) | ||
detector, | ||
patched_detector_request=PatchedDetectorRequest(confidence_threshold=confidence_threshold), | ||
_request_timeout=self._get_request_timeout(**kwargs), | ||
) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
An overall question: is there a sensible way to test this for all the methods we want to enable
request_timeout
through kwargs for?