diff --git a/src/sfapi_client/_async/client.py b/src/sfapi_client/_async/client.py index c8debb9..471aaf6 100644 --- a/src/sfapi_client/_async/client.py +++ b/src/sfapi_client/_async/client.py @@ -215,6 +215,7 @@ def __init__( key: Optional[Union[str, Path]] = None, api_base_url: Optional[str] = SFAPI_BASE_URL, token_url: Optional[str] = SFAPI_TOKEN_URL, + access_token: Optional[str] = None, wait_interval: int = 10, ): """ @@ -230,6 +231,10 @@ def __init__( :param client_id: The client ID :param secret: The client secret + :param key: The path to the client secret file + :param api_base_url: The API base URL + :param token_url: The token URL + :param access_token: An existing access token :return: The client instance :rtype: AsyncClient @@ -244,36 +249,45 @@ def __init__( self._api_base_url = api_base_url self._token_url = token_url self._client_user = None - self.__oauth2_session = None + self.__http_client = None self._api = None self._resources = None self._wait_interval = wait_interval + self._access_token = access_token async def __aenter__(self): return self - async def _oauth2_session(self): - if self._client_id is None: - raise SfApiError("No credentials have been provides") - - if self.__oauth2_session is None: - # Create a new session if we haven't already - self.__oauth2_session = AsyncOAuth2Client( - client_id=self._client_id, - client_secret=self._secret, - token_endpoint_auth_method=PrivateKeyJWT(self._token_url), - grant_type="client_credentials", - token_endpoint=self._token_url, - timeout=10.0, - ) + async def _http_client(self): + headers = {"accept": "application/json"} + # If we have a client_id then we need to use the OAuth2 client + if self._client_id is not None: + if self.__http_client is None: + # Create a new session if we haven't already + self.__http_client = AsyncOAuth2Client( + client_id=self._client_id, + client_secret=self._secret, + token_endpoint_auth_method=PrivateKeyJWT(self._token_url), + grant_type="client_credentials", + token_endpoint=self._token_url, + timeout=10.0, + headers=headers, + ) - await self.__oauth2_session.fetch_token() - else: - # We have a session - # Make sure it's still active - await self.__oauth2_session.ensure_active_token(self.__oauth2_session.token) + await self.__http_client.fetch_token() + else: + # We have a session + # Make sure it's still active + await self.__http_client.ensure_active_token(self.__http_client.token) + # Use regular client, but add the access token if we have one + elif self.__http_client is None: + # We already have an access token + if self._access_token is not None: + headers.update({"Authorization": f"Bearer {self._access_token}"}) + + self.__http_client = httpx.AsyncClient(headers=headers) - return self.__oauth2_session + return self.__http_client @property async def token(self) -> str: @@ -281,16 +295,18 @@ async def token(self) -> str: Bearer token string which can be helpful for debugging through swagger UI. """ - if self._client_id is not None: - oauth_session = await self._oauth2_session() - return oauth_session.token["access_token"] + if self._access_token is not None: + return self._access_token + elif self._client_id is not None: + client = await self._http_client() + return client.token["access_token"] async def close(self): """ Release resources associated with the client instance. """ - if self.__oauth2_session is not None: - await self.__oauth2_session.aclose() + if self.__http_client is not None: + await self.__http_client.aclose() async def __aexit__(self, type, value, traceback): await self.close() @@ -345,29 +361,11 @@ def _read_client_secret_from_file(self, name): stop=tenacity.stop_after_attempt(MAX_RETRY), ) async def get(self, url: str, params: Dict[str, Any] = {}) -> httpx.Response: - if self._client_id is not None: - oauth_session = await self._oauth2_session() - - r = await oauth_session.get( - f"{self._api_base_url}/{url}", - headers={ - "Authorization": oauth_session.token["access_token"], - "accept": "application/json", - }, - params=params, - ) - # Use regular client if we are hitting an endpoint that don't need - # auth. - else: - async with httpx.AsyncClient() as client: - r = await client.get( - f"{self._api_base_url}/{url}", - headers={ - "accept": "application/json", - }, - params=params, - ) - + client = await self._http_client() + r = await client.get( + f"{self._api_base_url}/{url}", + params=params, + ) r.raise_for_status() return r @@ -380,14 +378,10 @@ async def get(self, url: str, params: Dict[str, Any] = {}) -> httpx.Response: stop=tenacity.stop_after_attempt(MAX_RETRY), ) async def post(self, url: str, data: Dict[str, Any]) -> httpx.Response: - oauth_session = await self._oauth2_session() + client = await self._http_client() - r = await oauth_session.post( + r = await client.post( f"{self._api_base_url}/{url}", - headers={ - "Authorization": oauth_session.token["access_token"], - "accept": "application/json", - }, data=data, ) r.raise_for_status() @@ -404,14 +398,10 @@ async def post(self, url: str, data: Dict[str, Any]) -> httpx.Response: async def put( self, url: str, data: Dict[str, Any] = None, files: Dict[str, Any] = None ) -> httpx.Response: - oauth_session = await self._oauth2_session() + client = await self._http_client() - r = await oauth_session.put( + r = await client.put( f"{self._api_base_url}/{url}", - headers={ - "Authorization": oauth_session.token["access_token"], - "accept": "application/json", - }, data=data, files=files, ) @@ -427,14 +417,10 @@ async def put( stop=tenacity.stop_after_attempt(MAX_RETRY), ) async def delete(self, url: str) -> httpx.Response: - oauth_session = await self._oauth2_session() + client = await self._http_client() - r = await oauth_session.delete( + r = await client.delete( f"{self._api_base_url}/{url}", - headers={ - "Authorization": oauth_session.token["access_token"], - "accept": "application/json", - }, ) r.raise_for_status() diff --git a/src/sfapi_client/_async/compute.py b/src/sfapi_client/_async/compute.py index a320336..8a476c5 100644 --- a/src/sfapi_client/_async/compute.py +++ b/src/sfapi_client/_async/compute.py @@ -23,7 +23,7 @@ def check_auth(method: Callable): def wrapper(self, *args, **kwargs): - if self.client._client_id is None: + if self.client._client_id is None and self.client._access_token is None: raise SfApiError( f"Cannot call {self.__class__.__name__}.{method.__name__}() with an unauthenticated client." # noqa: E501 ) diff --git a/src/sfapi_client/_sync/client.py b/src/sfapi_client/_sync/client.py index cc6aae1..fb577cb 100644 --- a/src/sfapi_client/_sync/client.py +++ b/src/sfapi_client/_sync/client.py @@ -215,6 +215,7 @@ def __init__( key: Optional[Union[str, Path]] = None, api_base_url: Optional[str] = SFAPI_BASE_URL, token_url: Optional[str] = SFAPI_TOKEN_URL, + access_token: Optional[str] = None, wait_interval: int = 10, ): """ @@ -230,6 +231,10 @@ def __init__( :param client_id: The client ID :param secret: The client secret + :param key: The path to the client secret file + :param api_base_url: The API base URL + :param token_url: The token URL + :param access_token: An existing access token :return: The client instance :rtype: Client @@ -244,36 +249,45 @@ def __init__( self._api_base_url = api_base_url self._token_url = token_url self._client_user = None - self.__oauth2_session = None + self.__http_client = None self._api = None self._resources = None self._wait_interval = wait_interval + self._access_token = access_token def __enter__(self): return self - def _oauth2_session(self): - if self._client_id is None: - raise SfApiError("No credentials have been provides") - - if self.__oauth2_session is None: - # Create a new session if we haven't already - self.__oauth2_session = OAuth2Client( - client_id=self._client_id, - client_secret=self._secret, - token_endpoint_auth_method=PrivateKeyJWT(self._token_url), - grant_type="client_credentials", - token_endpoint=self._token_url, - timeout=10.0, - ) + def _http_client(self): + headers = {"accept": "application/json"} + # If we have a client_id then we need to use OAuth2 client + if self._client_id is not None: + if self.__http_client is None: + # Create a new session if we haven't already + self.__http_client = OAuth2Client( + client_id=self._client_id, + client_secret=self._secret, + token_endpoint_auth_method=PrivateKeyJWT(self._token_url), + grant_type="client_credentials", + token_endpoint=self._token_url, + timeout=10.0, + headers=headers, + ) - self.__oauth2_session.fetch_token() - else: - # We have a session - # Make sure it's still active - self.__oauth2_session.ensure_active_token(self.__oauth2_session.token) + self.__http_client.fetch_token() + else: + # We have a session + # Make sure it's still active + self.__http_client.ensure_active_token(self.__http_client.token) + # Use regular client, but add the access token if we have one + elif self.__http_client is None: + # We already have an access token + if self._access_token is not None: + headers.update({"Authorization": f"Bearer {self._access_token}"}) + + self.__http_client = httpx.Client(headers=headers) - return self.__oauth2_session + return self.__http_client @property def token(self) -> str: @@ -281,16 +295,18 @@ def token(self) -> str: Bearer token string which can be helpful for debugging through swagger UI. """ - if self._client_id is not None: - oauth_session = self._oauth2_session() - return oauth_session.token["access_token"] + if self._access_token is not None: + return self._access_token + elif self._client_id is not None: + client = self._http_client() + return client.token["access_token"] def close(self): """ Release resources associated with the client instance. """ - if self.__oauth2_session is not None: - self.__oauth2_session.close() + if self.__http_client is not None: + self.__http_client.close() def __exit__(self, type, value, traceback): self.close() @@ -345,29 +361,11 @@ def _read_client_secret_from_file(self, name): stop=tenacity.stop_after_attempt(MAX_RETRY), ) def get(self, url: str, params: Dict[str, Any] = {}) -> httpx.Response: - if self._client_id is not None: - oauth_session = self._oauth2_session() - - r = oauth_session.get( - f"{self._api_base_url}/{url}", - headers={ - "Authorization": oauth_session.token["access_token"], - "accept": "application/json", - }, - params=params, - ) - # Use regular client if we are hitting an endpoint that don't need - # auth. - else: - with httpx.Client() as client: - r = client.get( - f"{self._api_base_url}/{url}", - headers={ - "accept": "application/json", - }, - params=params, - ) - + client = self._http_client() + r = client.get( + f"{self._api_base_url}/{url}", + params=params, + ) r.raise_for_status() return r @@ -380,14 +378,10 @@ def get(self, url: str, params: Dict[str, Any] = {}) -> httpx.Response: stop=tenacity.stop_after_attempt(MAX_RETRY), ) def post(self, url: str, data: Dict[str, Any]) -> httpx.Response: - oauth_session = self._oauth2_session() + client = self._http_client() - r = oauth_session.post( + r = client.post( f"{self._api_base_url}/{url}", - headers={ - "Authorization": oauth_session.token["access_token"], - "accept": "application/json", - }, data=data, ) r.raise_for_status() @@ -404,14 +398,10 @@ def post(self, url: str, data: Dict[str, Any]) -> httpx.Response: def put( self, url: str, data: Dict[str, Any] = None, files: Dict[str, Any] = None ) -> httpx.Response: - oauth_session = self._oauth2_session() + client = self._http_client() - r = oauth_session.put( + r = client.put( f"{self._api_base_url}/{url}", - headers={ - "Authorization": oauth_session.token["access_token"], - "accept": "application/json", - }, data=data, files=files, ) @@ -427,14 +417,10 @@ def put( stop=tenacity.stop_after_attempt(MAX_RETRY), ) def delete(self, url: str) -> httpx.Response: - oauth_session = self._oauth2_session() + client = self._http_client() - r = oauth_session.delete( + r = client.delete( f"{self._api_base_url}/{url}", - headers={ - "Authorization": oauth_session.token["access_token"], - "accept": "application/json", - }, ) r.raise_for_status() diff --git a/src/sfapi_client/_sync/compute.py b/src/sfapi_client/_sync/compute.py index 1594aca..0c7ca0b 100644 --- a/src/sfapi_client/_sync/compute.py +++ b/src/sfapi_client/_sync/compute.py @@ -23,7 +23,7 @@ def check_auth(method: Callable): def wrapper(self, *args, **kwargs): - if self.client._client_id is None: + if self.client._client_id is None and self.client._access_token is None: raise SfApiError( f"Cannot call {self.__class__.__name__}.{method.__name__}() with an unauthenticated client." # noqa: E501 ) diff --git a/tests/conftest.py b/tests/conftest.py index daba646..ae78c77 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -31,6 +31,7 @@ class Settings(BaseSettings): TEST_GROUP: Optional[str] = None DEV_API_URL: str = "https://api-dev.nersc.gov/api/v1.2" DEV_TOKEN_URL: str = "https://oidc-dev.nersc.gov/c2id/token" + ACCESS_TOKEN: Optional[str] = None model_config = ConfigDict(case_sensitive=True, env_file=".env") @@ -168,3 +169,8 @@ def async_authenticated_client(api_base_url, token_url, client_id, client_secret client_id=client_id, secret=client_secret, ) + + +@pytest.fixture +def access_token(): + return settings.ACCESS_TOKEN diff --git a/tests/test_client.py b/tests/test_client.py index 64883a3..46e0350 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,6 +1,6 @@ import pytest -from sfapi_client import SfApiError +from sfapi_client import SfApiError, Client @pytest.mark.public @@ -15,3 +15,9 @@ def test_no_creds_auth_required(unauthenticated_client, test_machine): machine = client.compute(test_machine) with pytest.raises(SfApiError): machine.jobs() + + +def test_access_token(api_base_url, access_token, test_machine, test_username): + with Client(api_base_url=api_base_url, access_token=access_token) as client: + machine = client.compute(test_machine) + machine.jobs(user=test_username) diff --git a/tests/test_client_async.py b/tests/test_client_async.py index 2c44765..682149b 100644 --- a/tests/test_client_async.py +++ b/tests/test_client_async.py @@ -1,6 +1,6 @@ import pytest -from sfapi_client import SfApiError +from sfapi_client import SfApiError, AsyncClient @pytest.mark.public @@ -17,3 +17,12 @@ async def test_no_creds_auth_required(async_unauthenticated_client, test_machine machine = await client.compute(test_machine) with pytest.raises(SfApiError): await machine.jobs() + + +@pytest.mark.asyncio +async def test_access_token(api_base_url, access_token, test_machine, test_username): + async with AsyncClient( + api_base_url=api_base_url, access_token=access_token + ) as client: + machine = await client.compute(test_machine) + await machine.jobs(user=test_username)