Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow SDK methods to accept request_timeout through kwargs #332

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 49 additions & 15 deletions src/groundlight/client.py
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An overall question: is there a sensible way to test this for all the methods we want to enable request_timeout through kwargs for?

Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,11 @@ def _fixup_image_query(iq: ImageQuery) -> ImageQuery:
iq.result.label = convert_internal_label_to_display(iq, iq.result.label)
return iq

def whoami(self) -> str:
def _get_request_timeout(self, **kwargs):
"""Extract request_timeout from kwargs or use default."""
return kwargs.get("request_timeout", DEFAULT_REQUEST_TIMEOUT)

def whoami(self, **kwargs) -> str:
"""
Return the username (email address) associated with the current API token.

Expand All @@ -240,7 +244,7 @@ def whoami(self) -> str:
:raises ApiTokenError: If the API token is invalid
:raises GroundlightClientError: If there are connectivity issues with the Groundlight service
"""
obj = self.user_api.who_am_i(_request_timeout=DEFAULT_REQUEST_TIMEOUT)
obj = self.user_api.who_am_i(_request_timeout=self._get_request_timeout(**kwargs))
return obj["email"]

def _user_is_privileged(self) -> bool:
Expand All @@ -251,7 +255,7 @@ def _user_is_privileged(self) -> bool:
obj = self.user_api.who_am_i()
return obj["is_superuser"]

def get_detector(self, id: Union[str, Detector]) -> Detector: # pylint: disable=redefined-builtin
def get_detector(self, id: Union[str, Detector], **kwargs) -> Detector: # pylint: disable=redefined-builtin
"""
Get a Detector by id.

Expand All @@ -270,12 +274,14 @@ def get_detector(self, id: Union[str, Detector]) -> Detector: # pylint: disable
# Short-circuit
return id
try:
obj = self.detectors_api.get_detector(id=id, _request_timeout=DEFAULT_REQUEST_TIMEOUT)
obj = self.detectors_api.get_detector(id=id, _request_timeout=self._get_request_timeout(**kwargs))
except NotFoundException as e:
raise NotFoundError(f"Detector with id '{id}' not found") from e
return Detector.parse_obj(obj.to_dict())

def get_detector_by_name(self, name: str) -> Detector:
# TODO should methods like this (which make direct requests instead of going through the API) allow
# kwargs to be passed through?
Comment on lines 282 to +284
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's one type of tricky case for allowing request_timeout to be specified.

"""
Get a Detector by name.

Expand All @@ -291,7 +297,7 @@ def get_detector_by_name(self, name: str) -> Detector:
"""
return self.api_client._get_detector_by_name(name) # pylint: disable=protected-access

def list_detectors(self, page: int = 1, page_size: int = 10) -> PaginatedDetectorList:
def list_detectors(self, page: int = 1, page_size: int = 10, **kwargs) -> PaginatedDetectorList:
"""
Retrieve a paginated list of detectors associated with your account.

Expand All @@ -312,7 +318,7 @@ def list_detectors(self, page: int = 1, page_size: int = 10) -> PaginatedDetecto
:return: PaginatedDetectorList containing the requested page of detectors and pagination metadata
"""
obj = self.detectors_api.list_detectors(
page=page, page_size=page_size, _request_timeout=DEFAULT_REQUEST_TIMEOUT
page=page, page_size=page_size, _request_timeout=self._get_request_timeout(**kwargs)
)
return PaginatedDetectorList.parse_obj(obj.to_dict())

Expand Down Expand Up @@ -358,6 +364,7 @@ def create_detector( # noqa: PLR0913
patience_time: Optional[float] = None,
pipeline_config: Optional[str] = None,
metadata: Union[dict, str, None] = None,
**kwargs,
) -> Detector:
"""
Create a new Detector with a given name and query.
Expand Down Expand Up @@ -423,7 +430,9 @@ def create_detector( # noqa: PLR0913
pipeline_config=pipeline_config,
metadata=metadata,
)
obj = self.detectors_api.create_detector(detector_creation_input, _request_timeout=DEFAULT_REQUEST_TIMEOUT)
obj = self.detectors_api.create_detector(
detector_creation_input, _request_timeout=self._get_request_timeout(**kwargs)
)
return Detector.parse_obj(obj.to_dict())

def get_or_create_detector( # noqa: PLR0913
Expand All @@ -435,6 +444,7 @@ def get_or_create_detector( # noqa: PLR0913
confidence_threshold: Optional[float] = None,
pipeline_config: Optional[str] = None,
metadata: Union[dict, str, None] = None,
**kwargs,
) -> Detector:
"""
Tries to look up the Detector by name. If a Detector with that name, query, and
Expand Down Expand Up @@ -491,6 +501,7 @@ def get_or_create_detector( # noqa: PLR0913
confidence_threshold=confidence_threshold,
pipeline_config=pipeline_config,
metadata=metadata,
**kwargs,
)

# TODO: We may soon allow users to update the retrieved detector's fields.
Expand All @@ -512,7 +523,7 @@ def get_or_create_detector( # noqa: PLR0913
)
return existing_detector

def get_image_query(self, id: str) -> ImageQuery: # pylint: disable=redefined-builtin
def get_image_query(self, id: str, **kwargs) -> ImageQuery: # pylint: disable=redefined-builtin
"""
Get an ImageQuery by its ID. This is useful for retrieving the status and results of a
previously submitted query.
Expand All @@ -534,15 +545,15 @@ def get_image_query(self, id: str) -> ImageQuery: # pylint: disable=redefined-b

:return: ImageQuery object containing the query details and results
"""
obj = self.image_queries_api.get_image_query(id=id, _request_timeout=DEFAULT_REQUEST_TIMEOUT)
obj = self.image_queries_api.get_image_query(id=id, _request_timeout=self._get_request_timeout(**kwargs))
if obj.result_type == "counting" and getattr(obj.result, "label", None):
obj.result.pop("label")
obj.result["count"] = None
iq = ImageQuery.parse_obj(obj.to_dict())
return self._fixup_image_query(iq)

def list_image_queries(
self, page: int = 1, page_size: int = 10, detector_id: Union[str, None] = None
self, page: int = 1, page_size: int = 10, detector_id: Union[str, None] = None, **kwargs
) -> PaginatedImageQueryList:
"""
List all image queries associated with your account, with pagination support.
Expand All @@ -565,7 +576,11 @@ def list_image_queries(
:return: PaginatedImageQueryList containing the requested page of image queries and pagination metadata
like total count and links to next/previous pages.
"""
params: dict[str, Any] = {"page": page, "page_size": page_size, "_request_timeout": DEFAULT_REQUEST_TIMEOUT}
params: dict[str, Any] = {
"page": page,
"page_size": page_size,
"_request_timeout": self._get_request_timeout(**kwargs),
}
if detector_id:
params["detector_id"] = detector_id
obj = self.image_queries_api.list_image_queries(**params)
Expand All @@ -586,6 +601,7 @@ def submit_image_query( # noqa: PLR0913 # pylint: disable=too-many-arguments, t
inspection_id: Optional[str] = None,
metadata: Union[dict, str, None] = None,
image_query_id: Optional[str] = None,
**kwargs,
) -> ImageQuery:
"""
Evaluates an image with Groundlight. This is the core method for getting predictions about images.
Expand Down Expand Up @@ -680,7 +696,11 @@ def submit_image_query( # noqa: PLR0913 # pylint: disable=too-many-arguments, t

image_bytesio: ByteStreamWrapper = parse_supported_image_types(image)

params = {"detector_id": detector_id, "body": image_bytesio, "_request_timeout": DEFAULT_REQUEST_TIMEOUT}
params = {
"detector_id": detector_id,
"body": image_bytesio,
"_request_timeout": self._get_request_timeout(**kwargs),
}

if patience_time is not None:
params["patience_time"] = patience_time
Expand Down Expand Up @@ -732,6 +752,7 @@ def ask_confident( # noqa: PLR0913 # pylint: disable=too-many-arguments
wait: Optional[float] = None,
metadata: Union[dict, str, None] = None,
inspection_id: Optional[str] = None,
**kwargs,
) -> ImageQuery:
"""
Evaluates an image with Groundlight, waiting until an answer above the confidence threshold
Expand Down Expand Up @@ -788,6 +809,7 @@ def ask_confident( # noqa: PLR0913 # pylint: disable=too-many-arguments
human_review=None,
metadata=metadata,
inspection_id=inspection_id,
**kwargs,
)

def ask_ml( # noqa: PLR0913 # pylint: disable=too-many-arguments, too-many-locals
Expand All @@ -797,6 +819,7 @@ def ask_ml( # noqa: PLR0913 # pylint: disable=too-many-arguments, too-many-loca
wait: Optional[float] = None,
metadata: Union[dict, str, None] = None,
inspection_id: Optional[str] = None,
**kwargs,
) -> ImageQuery:
"""
Evaluates an image with Groundlight, getting the first ML prediction without waiting
Expand Down Expand Up @@ -856,6 +879,7 @@ def ask_ml( # noqa: PLR0913 # pylint: disable=too-many-arguments, too-many-loca
wait=0,
metadata=metadata,
inspection_id=inspection_id,
**kwargs,
)
if iq_is_answered(iq):
return iq
Expand All @@ -871,6 +895,7 @@ def ask_async( # noqa: PLR0913 # pylint: disable=too-many-arguments
human_review: Optional[str] = None,
metadata: Union[dict, str, None] = None,
inspection_id: Optional[str] = None,
**kwargs,
) -> ImageQuery:
"""
Submit an image query asynchronously. This is equivalent to calling `submit_image_query`
Expand Down Expand Up @@ -952,6 +977,7 @@ def ask_async( # noqa: PLR0913 # pylint: disable=too-many-arguments
want_async=True,
metadata=metadata,
inspection_id=inspection_id,
**kwargs,
)

def wait_for_confident_result(
Expand All @@ -960,6 +986,9 @@ def wait_for_confident_result(
confidence_threshold: Optional[float] = None,
timeout_sec: float = 30.0,
) -> ImageQuery:
# TODO should this method allow request_timeout to be passed in kwargs?
# It's a weird case because it might make multiple requests - would the specified request_timeout be applied
# to each request?
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is another tricky case for accepting request_timeout.

"""
Waits for an image query result's confidence level to reach the specified confidence_threshold.
Uses polling with exponential back-off to check for results.
Expand Down Expand Up @@ -1092,6 +1121,7 @@ def add_label(
image_query: Union[ImageQuery, str],
label: Union[Label, int, str],
rois: Union[List[ROI], str, None] = None,
**kwargs,
):
"""
Provide a new label (annotation) for an image query. This is used to provide ground-truth labels
Expand Down Expand Up @@ -1151,7 +1181,7 @@ def add_label(
else None
)
request_params = LabelValueRequest(label=label, image_query_id=image_query_id, rois=roi_requests)
self.labels_api.create_label(request_params)
self.labels_api.create_label(request_params, _request_timeout=self._get_request_timeout(**kwargs))

def start_inspection(self) -> str:
"""
Expand Down Expand Up @@ -1189,7 +1219,9 @@ def stop_inspection(self, inspection_id: str) -> str:
"""
return self.api_client.stop_inspection(inspection_id)

def update_detector_confidence_threshold(self, detector: Union[str, Detector], confidence_threshold: float) -> None:
def update_detector_confidence_threshold(
self, detector: Union[str, Detector], confidence_threshold: float, **kwargs
) -> None:
"""
Updates the confidence threshold for the given detector

Expand All @@ -1203,5 +1235,7 @@ def update_detector_confidence_threshold(self, detector: Union[str, Detector], c
if confidence_threshold < 0 or confidence_threshold > 1:
raise ValueError("confidence must be between 0 and 1")
self.detectors_api.update_detector(
detector, patched_detector_request=PatchedDetectorRequest(confidence_threshold=confidence_threshold)
detector,
patched_detector_request=PatchedDetectorRequest(confidence_threshold=confidence_threshold),
_request_timeout=self._get_request_timeout(**kwargs),
)
Loading