Skip to content

Commit 044b18a

Browse files
committed
potential implementation
1 parent cb3fd95 commit 044b18a

File tree

1 file changed

+44
-15
lines changed

1 file changed

+44
-15
lines changed

src/groundlight/client.py

+44-15
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,11 @@ def _fixup_image_query(iq: ImageQuery) -> ImageQuery:
223223
iq.result.label = convert_internal_label_to_display(iq, iq.result.label)
224224
return iq
225225

226-
def whoami(self) -> str:
226+
def _get_request_timeout(self, **kwargs):
227+
"""Extract request_timeout from kwargs or use default."""
228+
return kwargs.get("request_timeout", DEFAULT_REQUEST_TIMEOUT)
229+
230+
def whoami(self, **kwargs) -> str:
227231
"""
228232
Return the username (email address) associated with the current API token.
229233
@@ -240,7 +244,7 @@ def whoami(self) -> str:
240244
:raises ApiTokenError: If the API token is invalid
241245
:raises GroundlightClientError: If there are connectivity issues with the Groundlight service
242246
"""
243-
obj = self.user_api.who_am_i(_request_timeout=DEFAULT_REQUEST_TIMEOUT)
247+
obj = self.user_api.who_am_i(_request_timeout=self._get_request_timeout(**kwargs))
244248
return obj["email"]
245249

246250
def _user_is_privileged(self) -> bool:
@@ -251,7 +255,7 @@ def _user_is_privileged(self) -> bool:
251255
obj = self.user_api.who_am_i()
252256
return obj["is_superuser"]
253257

254-
def get_detector(self, id: Union[str, Detector]) -> Detector: # pylint: disable=redefined-builtin
258+
def get_detector(self, id: Union[str, Detector], **kwargs) -> Detector: # pylint: disable=redefined-builtin
255259
"""
256260
Get a Detector by id.
257261
@@ -270,7 +274,7 @@ def get_detector(self, id: Union[str, Detector]) -> Detector: # pylint: disable
270274
# Short-circuit
271275
return id
272276
try:
273-
obj = self.detectors_api.get_detector(id=id, _request_timeout=DEFAULT_REQUEST_TIMEOUT)
277+
obj = self.detectors_api.get_detector(id=id, _request_timeout=self._get_request_timeout(**kwargs))
274278
except NotFoundException as e:
275279
raise NotFoundError(f"Detector with id '{id}' not found") from e
276280
return Detector.parse_obj(obj.to_dict())
@@ -291,7 +295,7 @@ def get_detector_by_name(self, name: str) -> Detector:
291295
"""
292296
return self.api_client._get_detector_by_name(name) # pylint: disable=protected-access
293297

294-
def list_detectors(self, page: int = 1, page_size: int = 10) -> PaginatedDetectorList:
298+
def list_detectors(self, page: int = 1, page_size: int = 10, **kwargs) -> PaginatedDetectorList:
295299
"""
296300
Retrieve a paginated list of detectors associated with your account.
297301
@@ -312,7 +316,7 @@ def list_detectors(self, page: int = 1, page_size: int = 10) -> PaginatedDetecto
312316
:return: PaginatedDetectorList containing the requested page of detectors and pagination metadata
313317
"""
314318
obj = self.detectors_api.list_detectors(
315-
page=page, page_size=page_size, _request_timeout=DEFAULT_REQUEST_TIMEOUT
319+
page=page, page_size=page_size, _request_timeout=self._get_request_timeout(**kwargs)
316320
)
317321
return PaginatedDetectorList.parse_obj(obj.to_dict())
318322

@@ -358,6 +362,7 @@ def create_detector( # noqa: PLR0913
358362
patience_time: Optional[float] = None,
359363
pipeline_config: Optional[str] = None,
360364
metadata: Union[dict, str, None] = None,
365+
**kwargs,
361366
) -> Detector:
362367
"""
363368
Create a new Detector with a given name and query.
@@ -423,7 +428,9 @@ def create_detector( # noqa: PLR0913
423428
pipeline_config=pipeline_config,
424429
metadata=metadata,
425430
)
426-
obj = self.detectors_api.create_detector(detector_creation_input, _request_timeout=DEFAULT_REQUEST_TIMEOUT)
431+
obj = self.detectors_api.create_detector(
432+
detector_creation_input, _request_timeout=self._get_request_timeout(**kwargs)
433+
)
427434
return Detector.parse_obj(obj.to_dict())
428435

429436
def get_or_create_detector( # noqa: PLR0913
@@ -435,6 +442,7 @@ def get_or_create_detector( # noqa: PLR0913
435442
confidence_threshold: Optional[float] = None,
436443
pipeline_config: Optional[str] = None,
437444
metadata: Union[dict, str, None] = None,
445+
**kwargs,
438446
) -> Detector:
439447
"""
440448
Tries to look up the Detector by name. If a Detector with that name, query, and
@@ -491,6 +499,7 @@ def get_or_create_detector( # noqa: PLR0913
491499
confidence_threshold=confidence_threshold,
492500
pipeline_config=pipeline_config,
493501
metadata=metadata,
502+
**kwargs,
494503
)
495504

496505
# TODO: We may soon allow users to update the retrieved detector's fields.
@@ -512,7 +521,7 @@ def get_or_create_detector( # noqa: PLR0913
512521
)
513522
return existing_detector
514523

515-
def get_image_query(self, id: str) -> ImageQuery: # pylint: disable=redefined-builtin
524+
def get_image_query(self, id: str, **kwargs) -> ImageQuery: # pylint: disable=redefined-builtin
516525
"""
517526
Get an ImageQuery by its ID. This is useful for retrieving the status and results of a
518527
previously submitted query.
@@ -534,15 +543,15 @@ def get_image_query(self, id: str) -> ImageQuery: # pylint: disable=redefined-b
534543
535544
:return: ImageQuery object containing the query details and results
536545
"""
537-
obj = self.image_queries_api.get_image_query(id=id, _request_timeout=DEFAULT_REQUEST_TIMEOUT)
546+
obj = self.image_queries_api.get_image_query(id=id, _request_timeout=self._get_request_timeout(**kwargs))
538547
if obj.result_type == "counting" and getattr(obj.result, "label", None):
539548
obj.result.pop("label")
540549
obj.result["count"] = None
541550
iq = ImageQuery.parse_obj(obj.to_dict())
542551
return self._fixup_image_query(iq)
543552

544553
def list_image_queries(
545-
self, page: int = 1, page_size: int = 10, detector_id: Union[str, None] = None
554+
self, page: int = 1, page_size: int = 10, detector_id: Union[str, None] = None, **kwargs
546555
) -> PaginatedImageQueryList:
547556
"""
548557
List all image queries associated with your account, with pagination support.
@@ -565,7 +574,11 @@ def list_image_queries(
565574
:return: PaginatedImageQueryList containing the requested page of image queries and pagination metadata
566575
like total count and links to next/previous pages.
567576
"""
568-
params: dict[str, Any] = {"page": page, "page_size": page_size, "_request_timeout": DEFAULT_REQUEST_TIMEOUT}
577+
params: dict[str, Any] = {
578+
"page": page,
579+
"page_size": page_size,
580+
"_request_timeout": self._get_request_timeout(**kwargs),
581+
}
569582
if detector_id:
570583
params["detector_id"] = detector_id
571584
obj = self.image_queries_api.list_image_queries(**params)
@@ -586,6 +599,7 @@ def submit_image_query( # noqa: PLR0913 # pylint: disable=too-many-arguments, t
586599
inspection_id: Optional[str] = None,
587600
metadata: Union[dict, str, None] = None,
588601
image_query_id: Optional[str] = None,
602+
**kwargs,
589603
) -> ImageQuery:
590604
"""
591605
Evaluates an image with Groundlight. This is the core method for getting predictions about images.
@@ -680,7 +694,11 @@ def submit_image_query( # noqa: PLR0913 # pylint: disable=too-many-arguments, t
680694

681695
image_bytesio: ByteStreamWrapper = parse_supported_image_types(image)
682696

683-
params = {"detector_id": detector_id, "body": image_bytesio, "_request_timeout": DEFAULT_REQUEST_TIMEOUT}
697+
params = {
698+
"detector_id": detector_id,
699+
"body": image_bytesio,
700+
"_request_timeout": self._get_request_timeout(**kwargs),
701+
}
684702

685703
if patience_time is not None:
686704
params["patience_time"] = patience_time
@@ -732,6 +750,7 @@ def ask_confident( # noqa: PLR0913 # pylint: disable=too-many-arguments
732750
wait: Optional[float] = None,
733751
metadata: Union[dict, str, None] = None,
734752
inspection_id: Optional[str] = None,
753+
**kwargs,
735754
) -> ImageQuery:
736755
"""
737756
Evaluates an image with Groundlight, waiting until an answer above the confidence threshold
@@ -788,6 +807,7 @@ def ask_confident( # noqa: PLR0913 # pylint: disable=too-many-arguments
788807
human_review=None,
789808
metadata=metadata,
790809
inspection_id=inspection_id,
810+
**kwargs,
791811
)
792812

793813
def ask_ml( # noqa: PLR0913 # pylint: disable=too-many-arguments, too-many-locals
@@ -797,6 +817,7 @@ def ask_ml( # noqa: PLR0913 # pylint: disable=too-many-arguments, too-many-loca
797817
wait: Optional[float] = None,
798818
metadata: Union[dict, str, None] = None,
799819
inspection_id: Optional[str] = None,
820+
**kwargs,
800821
) -> ImageQuery:
801822
"""
802823
Evaluates an image with Groundlight, getting the first ML prediction without waiting
@@ -856,6 +877,7 @@ def ask_ml( # noqa: PLR0913 # pylint: disable=too-many-arguments, too-many-loca
856877
wait=0,
857878
metadata=metadata,
858879
inspection_id=inspection_id,
880+
**kwargs,
859881
)
860882
if iq_is_answered(iq):
861883
return iq
@@ -871,6 +893,7 @@ def ask_async( # noqa: PLR0913 # pylint: disable=too-many-arguments
871893
human_review: Optional[str] = None,
872894
metadata: Union[dict, str, None] = None,
873895
inspection_id: Optional[str] = None,
896+
**kwargs,
874897
) -> ImageQuery:
875898
"""
876899
Submit an image query asynchronously. This is equivalent to calling `submit_image_query`
@@ -952,6 +975,7 @@ def ask_async( # noqa: PLR0913 # pylint: disable=too-many-arguments
952975
want_async=True,
953976
metadata=metadata,
954977
inspection_id=inspection_id,
978+
**kwargs,
955979
)
956980

957981
def wait_for_confident_result(
@@ -1092,6 +1116,7 @@ def add_label(
10921116
image_query: Union[ImageQuery, str],
10931117
label: Union[Label, int, str],
10941118
rois: Union[List[ROI], str, None] = None,
1119+
**kwargs,
10951120
):
10961121
"""
10971122
Provide a new label (annotation) for an image query. This is used to provide ground-truth labels
@@ -1151,7 +1176,7 @@ def add_label(
11511176
else None
11521177
)
11531178
request_params = LabelValueRequest(label=label, image_query_id=image_query_id, rois=roi_requests)
1154-
self.labels_api.create_label(request_params)
1179+
self.labels_api.create_label(request_params, _request_timeout=self._get_request_timeout(**kwargs))
11551180

11561181
def start_inspection(self) -> str:
11571182
"""
@@ -1189,7 +1214,9 @@ def stop_inspection(self, inspection_id: str) -> str:
11891214
"""
11901215
return self.api_client.stop_inspection(inspection_id)
11911216

1192-
def update_detector_confidence_threshold(self, detector: Union[str, Detector], confidence_threshold: float) -> None:
1217+
def update_detector_confidence_threshold(
1218+
self, detector: Union[str, Detector], confidence_threshold: float, **kwargs
1219+
) -> None:
11931220
"""
11941221
Updates the confidence threshold for the given detector
11951222
@@ -1203,5 +1230,7 @@ def update_detector_confidence_threshold(self, detector: Union[str, Detector], c
12031230
if confidence_threshold < 0 or confidence_threshold > 1:
12041231
raise ValueError("confidence must be between 0 and 1")
12051232
self.detectors_api.update_detector(
1206-
detector, patched_detector_request=PatchedDetectorRequest(confidence_threshold=confidence_threshold)
1233+
detector,
1234+
patched_detector_request=PatchedDetectorRequest(confidence_threshold=confidence_threshold),
1235+
_request_timeout=self._get_request_timeout(**kwargs),
12071236
)

0 commit comments

Comments
 (0)