diff --git a/src/groundlight/client.py b/src/groundlight/client.py index 9a55a906..42c07bea 100644 --- a/src/groundlight/client.py +++ b/src/groundlight/client.py @@ -223,7 +223,11 @@ def _fixup_image_query(iq: ImageQuery) -> ImageQuery: iq.result.label = convert_internal_label_to_display(iq, iq.result.label) return iq - def whoami(self) -> str: + def _get_request_timeout(self, **kwargs): + """Extract request_timeout from kwargs or use default.""" + return kwargs.get("request_timeout", DEFAULT_REQUEST_TIMEOUT) + + def whoami(self, **kwargs) -> str: """ Return the username (email address) associated with the current API token. @@ -240,7 +244,7 @@ def whoami(self) -> str: :raises ApiTokenError: If the API token is invalid :raises GroundlightClientError: If there are connectivity issues with the Groundlight service """ - obj = self.user_api.who_am_i(_request_timeout=DEFAULT_REQUEST_TIMEOUT) + obj = self.user_api.who_am_i(_request_timeout=self._get_request_timeout(**kwargs)) return obj["email"] def _user_is_privileged(self) -> bool: @@ -251,7 +255,7 @@ def _user_is_privileged(self) -> bool: obj = self.user_api.who_am_i() return obj["is_superuser"] - def get_detector(self, id: Union[str, Detector]) -> Detector: # pylint: disable=redefined-builtin + def get_detector(self, id: Union[str, Detector], **kwargs) -> Detector: # pylint: disable=redefined-builtin """ Get a Detector by id. @@ -270,12 +274,14 @@ def get_detector(self, id: Union[str, Detector]) -> Detector: # pylint: disable # Short-circuit return id try: - obj = self.detectors_api.get_detector(id=id, _request_timeout=DEFAULT_REQUEST_TIMEOUT) + obj = self.detectors_api.get_detector(id=id, _request_timeout=self._get_request_timeout(**kwargs)) except NotFoundException as e: raise NotFoundError(f"Detector with id '{id}' not found") from e return Detector.parse_obj(obj.to_dict()) def get_detector_by_name(self, name: str) -> Detector: + # TODO should methods like this (which make direct requests instead of going through the API) allow + # kwargs to be passed through? """ Get a Detector by name. @@ -291,7 +297,7 @@ def get_detector_by_name(self, name: str) -> Detector: """ return self.api_client._get_detector_by_name(name) # pylint: disable=protected-access - def list_detectors(self, page: int = 1, page_size: int = 10) -> PaginatedDetectorList: + def list_detectors(self, page: int = 1, page_size: int = 10, **kwargs) -> PaginatedDetectorList: """ Retrieve a paginated list of detectors associated with your account. @@ -312,7 +318,7 @@ def list_detectors(self, page: int = 1, page_size: int = 10) -> PaginatedDetecto :return: PaginatedDetectorList containing the requested page of detectors and pagination metadata """ obj = self.detectors_api.list_detectors( - page=page, page_size=page_size, _request_timeout=DEFAULT_REQUEST_TIMEOUT + page=page, page_size=page_size, _request_timeout=self._get_request_timeout(**kwargs) ) return PaginatedDetectorList.parse_obj(obj.to_dict()) @@ -358,6 +364,7 @@ def create_detector( # noqa: PLR0913 patience_time: Optional[float] = None, pipeline_config: Optional[str] = None, metadata: Union[dict, str, None] = None, + **kwargs, ) -> Detector: """ Create a new Detector with a given name and query. @@ -423,7 +430,9 @@ def create_detector( # noqa: PLR0913 pipeline_config=pipeline_config, metadata=metadata, ) - obj = self.detectors_api.create_detector(detector_creation_input, _request_timeout=DEFAULT_REQUEST_TIMEOUT) + obj = self.detectors_api.create_detector( + detector_creation_input, _request_timeout=self._get_request_timeout(**kwargs) + ) return Detector.parse_obj(obj.to_dict()) def get_or_create_detector( # noqa: PLR0913 @@ -435,6 +444,7 @@ def get_or_create_detector( # noqa: PLR0913 confidence_threshold: Optional[float] = None, pipeline_config: Optional[str] = None, metadata: Union[dict, str, None] = None, + **kwargs, ) -> Detector: """ Tries to look up the Detector by name. If a Detector with that name, query, and @@ -491,6 +501,7 @@ def get_or_create_detector( # noqa: PLR0913 confidence_threshold=confidence_threshold, pipeline_config=pipeline_config, metadata=metadata, + **kwargs, ) # TODO: We may soon allow users to update the retrieved detector's fields. @@ -512,7 +523,7 @@ def get_or_create_detector( # noqa: PLR0913 ) return existing_detector - def get_image_query(self, id: str) -> ImageQuery: # pylint: disable=redefined-builtin + def get_image_query(self, id: str, **kwargs) -> ImageQuery: # pylint: disable=redefined-builtin """ Get an ImageQuery by its ID. This is useful for retrieving the status and results of a previously submitted query. @@ -534,7 +545,7 @@ def get_image_query(self, id: str) -> ImageQuery: # pylint: disable=redefined-b :return: ImageQuery object containing the query details and results """ - obj = self.image_queries_api.get_image_query(id=id, _request_timeout=DEFAULT_REQUEST_TIMEOUT) + obj = self.image_queries_api.get_image_query(id=id, _request_timeout=self._get_request_timeout(**kwargs)) if obj.result_type == "counting" and getattr(obj.result, "label", None): obj.result.pop("label") obj.result["count"] = None @@ -542,7 +553,7 @@ def get_image_query(self, id: str) -> ImageQuery: # pylint: disable=redefined-b return self._fixup_image_query(iq) def list_image_queries( - self, page: int = 1, page_size: int = 10, detector_id: Union[str, None] = None + self, page: int = 1, page_size: int = 10, detector_id: Union[str, None] = None, **kwargs ) -> PaginatedImageQueryList: """ List all image queries associated with your account, with pagination support. @@ -565,7 +576,11 @@ def list_image_queries( :return: PaginatedImageQueryList containing the requested page of image queries and pagination metadata like total count and links to next/previous pages. """ - params: dict[str, Any] = {"page": page, "page_size": page_size, "_request_timeout": DEFAULT_REQUEST_TIMEOUT} + params: dict[str, Any] = { + "page": page, + "page_size": page_size, + "_request_timeout": self._get_request_timeout(**kwargs), + } if detector_id: params["detector_id"] = detector_id obj = self.image_queries_api.list_image_queries(**params) @@ -586,6 +601,7 @@ def submit_image_query( # noqa: PLR0913 # pylint: disable=too-many-arguments, t inspection_id: Optional[str] = None, metadata: Union[dict, str, None] = None, image_query_id: Optional[str] = None, + **kwargs, ) -> ImageQuery: """ Evaluates an image with Groundlight. This is the core method for getting predictions about images. @@ -680,7 +696,11 @@ def submit_image_query( # noqa: PLR0913 # pylint: disable=too-many-arguments, t image_bytesio: ByteStreamWrapper = parse_supported_image_types(image) - params = {"detector_id": detector_id, "body": image_bytesio, "_request_timeout": DEFAULT_REQUEST_TIMEOUT} + params = { + "detector_id": detector_id, + "body": image_bytesio, + "_request_timeout": self._get_request_timeout(**kwargs), + } if patience_time is not None: params["patience_time"] = patience_time @@ -732,6 +752,7 @@ def ask_confident( # noqa: PLR0913 # pylint: disable=too-many-arguments wait: Optional[float] = None, metadata: Union[dict, str, None] = None, inspection_id: Optional[str] = None, + **kwargs, ) -> ImageQuery: """ Evaluates an image with Groundlight, waiting until an answer above the confidence threshold @@ -788,6 +809,7 @@ def ask_confident( # noqa: PLR0913 # pylint: disable=too-many-arguments human_review=None, metadata=metadata, inspection_id=inspection_id, + **kwargs, ) def ask_ml( # noqa: PLR0913 # pylint: disable=too-many-arguments, too-many-locals @@ -797,6 +819,7 @@ def ask_ml( # noqa: PLR0913 # pylint: disable=too-many-arguments, too-many-loca wait: Optional[float] = None, metadata: Union[dict, str, None] = None, inspection_id: Optional[str] = None, + **kwargs, ) -> ImageQuery: """ Evaluates an image with Groundlight, getting the first ML prediction without waiting @@ -856,6 +879,7 @@ def ask_ml( # noqa: PLR0913 # pylint: disable=too-many-arguments, too-many-loca wait=0, metadata=metadata, inspection_id=inspection_id, + **kwargs, ) if iq_is_answered(iq): return iq @@ -871,6 +895,7 @@ def ask_async( # noqa: PLR0913 # pylint: disable=too-many-arguments human_review: Optional[str] = None, metadata: Union[dict, str, None] = None, inspection_id: Optional[str] = None, + **kwargs, ) -> ImageQuery: """ Submit an image query asynchronously. This is equivalent to calling `submit_image_query` @@ -952,6 +977,7 @@ def ask_async( # noqa: PLR0913 # pylint: disable=too-many-arguments want_async=True, metadata=metadata, inspection_id=inspection_id, + **kwargs, ) def wait_for_confident_result( @@ -960,6 +986,9 @@ def wait_for_confident_result( confidence_threshold: Optional[float] = None, timeout_sec: float = 30.0, ) -> ImageQuery: + # TODO should this method allow request_timeout to be passed in kwargs? + # It's a weird case because it might make multiple requests - would the specified request_timeout be applied + # to each request? """ Waits for an image query result's confidence level to reach the specified confidence_threshold. Uses polling with exponential back-off to check for results. @@ -1092,6 +1121,7 @@ def add_label( image_query: Union[ImageQuery, str], label: Union[Label, int, str], rois: Union[List[ROI], str, None] = None, + **kwargs, ): """ Provide a new label (annotation) for an image query. This is used to provide ground-truth labels @@ -1151,7 +1181,7 @@ def add_label( else None ) request_params = LabelValueRequest(label=label, image_query_id=image_query_id, rois=roi_requests) - self.labels_api.create_label(request_params) + self.labels_api.create_label(request_params, _request_timeout=self._get_request_timeout(**kwargs)) def start_inspection(self) -> str: """ @@ -1189,7 +1219,9 @@ def stop_inspection(self, inspection_id: str) -> str: """ return self.api_client.stop_inspection(inspection_id) - def update_detector_confidence_threshold(self, detector: Union[str, Detector], confidence_threshold: float) -> None: + def update_detector_confidence_threshold( + self, detector: Union[str, Detector], confidence_threshold: float, **kwargs + ) -> None: """ Updates the confidence threshold for the given detector @@ -1203,5 +1235,7 @@ def update_detector_confidence_threshold(self, detector: Union[str, Detector], c if confidence_threshold < 0 or confidence_threshold > 1: raise ValueError("confidence must be between 0 and 1") self.detectors_api.update_detector( - detector, patched_detector_request=PatchedDetectorRequest(confidence_threshold=confidence_threshold) + detector, + patched_detector_request=PatchedDetectorRequest(confidence_threshold=confidence_threshold), + _request_timeout=self._get_request_timeout(**kwargs), )