Skip to content

Commit

Permalink
[FEATURE] Add retries to the internal httpx.Client used by the SDK (#…
Browse files Browse the repository at this point in the history
…5386)

# Description

This PR adds a new argument `retries` that can be used to specify the
number of times that an HTTP request performed by the internal
`httpx.Client` used by the SDK should be retried before raising an
exception.

This is useful as sometimes while using the `dataset.records.log` you
can receive a `ConnectionError` or 5xx from the server and if you retry
a few seconds later everything is fine.

**Type of change**

- Improvement

**How Has This Been Tested**
<!-- Please add some reference about how your feature has been tested.
-->

**Checklist**

- I added relevant documentation
- I followed the style guidelines of this project
- I did a self-review of my code
- I made corresponding changes to the documentation
- I confirm My changes generate no new warnings
- I have added tests that prove my fix is effective or that my feature
works
- I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)

---------

Co-authored-by: Ben Burtenshaw <[email protected]>
Co-authored-by: burtenshaw <[email protected]>
Co-authored-by: Paco Aranda <[email protected]>
  • Loading branch information
4 people authored Sep 13, 2024
1 parent a9dd0fb commit f5ff647
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 17 deletions.
4 changes: 3 additions & 1 deletion argilla/src/argilla/_api/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,9 @@ class APIClient:
def __init__(
self,
api_url: Optional[str] = DEFAULT_HTTP_CONFIG.api_url,
api_key: str = DEFAULT_HTTP_CONFIG.api_key,
api_key: Optional[str] = DEFAULT_HTTP_CONFIG.api_key,
timeout: int = DEFAULT_HTTP_CONFIG.timeout,
retries: int = DEFAULT_HTTP_CONFIG.retries,
**http_client_args,
):
if not api_url:
Expand All @@ -120,6 +121,7 @@ def __init__(

http_client_args = http_client_args or {}
http_client_args["timeout"] = timeout
http_client_args["retries"] = retries

self.http_client = create_http_client(
api_url=self.api_url, # type: ignore
Expand Down
18 changes: 10 additions & 8 deletions argilla/src/argilla/_api/_http/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,8 @@ class HTTPClientConfig:

api_url: str
api_key: str
timeout: int = None

def __post_init__(self):
self.api_url = self.api_url
self.api_key = self.api_key
self.timeout = self.timeout or 60
timeout: int = 60
retries: int = 5


def create_http_client(api_url: str, api_key: str, **client_args) -> httpx.Client:
Expand All @@ -37,5 +33,11 @@ def create_http_client(api_url: str, api_key: str, **client_args) -> httpx.Clien

headers = client_args.pop("headers", {})
headers["X-Argilla-Api-Key"] = api_key

return httpx.Client(base_url=api_url, headers=headers, **client_args)
retries = client_args.pop("retries", 0)

return httpx.Client(
base_url=api_url,
headers=headers,
transport=httpx.HTTPTransport(retries=retries),
**client_args,
)
23 changes: 16 additions & 7 deletions argilla/src/argilla/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,8 @@ class Argilla(_api.APIClient):
datasets: A collection of datasets.
users: A collection of users.
me: The current user.
"""

workspaces: "Workspaces"
datasets: "Datasets"
users: "Users"
me: "User"

# Default instance of Argilla
_default_client: Optional["Argilla"] = None

Expand All @@ -57,9 +51,24 @@ def __init__(
api_url: Optional[str] = DEFAULT_HTTP_CONFIG.api_url,
api_key: Optional[str] = DEFAULT_HTTP_CONFIG.api_key,
timeout: int = DEFAULT_HTTP_CONFIG.timeout,
retries: int = DEFAULT_HTTP_CONFIG.retries,
**http_client_args,
) -> None:
super().__init__(api_url=api_url, api_key=api_key, timeout=timeout, **http_client_args)
"""Inits the `Argilla` client.
Args:
api_url: the URL of the Argilla API. If not provided, then the value will try
to be set from `ARGILLA_API_URL` environment variable. Defaults to
`"http://localhost:6900"`.
api_key: the key to be used to authenticate in the Argilla API. If not provided,
then the value will try to be set from `ARGILLA_API_KEY` environment variable.
Defaults to `None`.
timeout: the maximum time in seconds to wait for a request to the Argilla API
to be completed before raising an exception. Defaults to `60`.
retries: the number of times to retry the HTTP connection to the Argilla API
before raising an exception. Defaults to `5`.
"""
super().__init__(api_url=api_url, api_key=api_key, timeout=timeout, retries=retries, **http_client_args)

self._set_default(self)

Expand Down
16 changes: 15 additions & 1 deletion argilla/tests/unit/api/http/test_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from httpx import Timeout
from unittest.mock import MagicMock, patch

import pytest
from argilla import Argilla
from httpx import Timeout


class TestHTTPClient:
Expand Down Expand Up @@ -62,3 +64,15 @@ def test_create_client_with_extra_cookies(self):
assert http_client.base_url == "http://localhost:6900"
assert http_client.headers["X-Argilla-Api-Key"] == "argilla.apikey"
assert http_client.cookies["session"] == "session_id"

@pytest.mark.parametrize("retries", [0, 1, 5, 10])
def test_create_client_with_various_retries(self, retries):
with patch("argilla._api._client.create_http_client") as mock_create_http_client:
mock_http_client = MagicMock()
mock_create_http_client.return_value = mock_http_client

Argilla(api_url="http://test.com", api_key="test_key", retries=retries)

mock_create_http_client.assert_called_once_with(
api_url="http://test.com", api_key="test_key", timeout=60, retries=retries
)

0 comments on commit f5ff647

Please sign in to comment.