Skip to content

Commit

Permalink
Merge pull request #79 from NERSC/auth_token
Browse files Browse the repository at this point in the history
Allow an existing access token to be used with client
  • Loading branch information
cjh1 authored Aug 5, 2024
2 parents 0ab2445 + d5a2548 commit 7fb4175
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 138 deletions.
120 changes: 53 additions & 67 deletions src/sfapi_client/_async/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def __init__(
key: Optional[Union[str, Path]] = None,
api_base_url: Optional[str] = SFAPI_BASE_URL,
token_url: Optional[str] = SFAPI_TOKEN_URL,
access_token: Optional[str] = None,
wait_interval: int = 10,
):
"""
Expand All @@ -230,6 +231,10 @@ def __init__(
:param client_id: The client ID
:param secret: The client secret
:param key: The path to the client secret file
:param api_base_url: The API base URL
:param token_url: The token URL
:param access_token: An existing access token
:return: The client instance
:rtype: AsyncClient
Expand All @@ -244,53 +249,64 @@ def __init__(
self._api_base_url = api_base_url
self._token_url = token_url
self._client_user = None
self.__oauth2_session = None
self.__http_client = None
self._api = None
self._resources = None
self._wait_interval = wait_interval
self._access_token = access_token

async def __aenter__(self):
return self

async def _oauth2_session(self):
if self._client_id is None:
raise SfApiError("No credentials have been provides")

if self.__oauth2_session is None:
# Create a new session if we haven't already
self.__oauth2_session = AsyncOAuth2Client(
client_id=self._client_id,
client_secret=self._secret,
token_endpoint_auth_method=PrivateKeyJWT(self._token_url),
grant_type="client_credentials",
token_endpoint=self._token_url,
timeout=10.0,
)
async def _http_client(self):
headers = {"accept": "application/json"}
# If we have a client_id then we need to use the OAuth2 client
if self._client_id is not None:
if self.__http_client is None:
# Create a new session if we haven't already
self.__http_client = AsyncOAuth2Client(
client_id=self._client_id,
client_secret=self._secret,
token_endpoint_auth_method=PrivateKeyJWT(self._token_url),
grant_type="client_credentials",
token_endpoint=self._token_url,
timeout=10.0,
headers=headers,
)

await self.__oauth2_session.fetch_token()
else:
# We have a session
# Make sure it's still active
await self.__oauth2_session.ensure_active_token(self.__oauth2_session.token)
await self.__http_client.fetch_token()
else:
# We have a session
# Make sure it's still active
await self.__http_client.ensure_active_token(self.__http_client.token)
# Use regular client, but add the access token if we have one
elif self.__http_client is None:
# We already have an access token
if self._access_token is not None:
headers.update({"Authorization": f"Bearer {self._access_token}"})

self.__http_client = httpx.AsyncClient(headers=headers)

return self.__oauth2_session
return self.__http_client

@property
async def token(self) -> str:
"""
Bearer token string which can be helpful for debugging through swagger UI.
"""

if self._client_id is not None:
oauth_session = await self._oauth2_session()
return oauth_session.token["access_token"]
if self._access_token is not None:
return self._access_token
elif self._client_id is not None:
client = await self._http_client()
return client.token["access_token"]

async def close(self):
"""
Release resources associated with the client instance.
"""
if self.__oauth2_session is not None:
await self.__oauth2_session.aclose()
if self.__http_client is not None:
await self.__http_client.aclose()

async def __aexit__(self, type, value, traceback):
await self.close()
Expand Down Expand Up @@ -345,29 +361,11 @@ def _read_client_secret_from_file(self, name):
stop=tenacity.stop_after_attempt(MAX_RETRY),
)
async def get(self, url: str, params: Dict[str, Any] = {}) -> httpx.Response:
if self._client_id is not None:
oauth_session = await self._oauth2_session()

r = await oauth_session.get(
f"{self._api_base_url}/{url}",
headers={
"Authorization": oauth_session.token["access_token"],
"accept": "application/json",
},
params=params,
)
# Use regular client if we are hitting an endpoint that don't need
# auth.
else:
async with httpx.AsyncClient() as client:
r = await client.get(
f"{self._api_base_url}/{url}",
headers={
"accept": "application/json",
},
params=params,
)

client = await self._http_client()
r = await client.get(
f"{self._api_base_url}/{url}",
params=params,
)
r.raise_for_status()

return r
Expand All @@ -380,14 +378,10 @@ async def get(self, url: str, params: Dict[str, Any] = {}) -> httpx.Response:
stop=tenacity.stop_after_attempt(MAX_RETRY),
)
async def post(self, url: str, data: Dict[str, Any]) -> httpx.Response:
oauth_session = await self._oauth2_session()
client = await self._http_client()

r = await oauth_session.post(
r = await client.post(
f"{self._api_base_url}/{url}",
headers={
"Authorization": oauth_session.token["access_token"],
"accept": "application/json",
},
data=data,
)
r.raise_for_status()
Expand All @@ -404,14 +398,10 @@ async def post(self, url: str, data: Dict[str, Any]) -> httpx.Response:
async def put(
self, url: str, data: Dict[str, Any] = None, files: Dict[str, Any] = None
) -> httpx.Response:
oauth_session = await self._oauth2_session()
client = await self._http_client()

r = await oauth_session.put(
r = await client.put(
f"{self._api_base_url}/{url}",
headers={
"Authorization": oauth_session.token["access_token"],
"accept": "application/json",
},
data=data,
files=files,
)
Expand All @@ -427,14 +417,10 @@ async def put(
stop=tenacity.stop_after_attempt(MAX_RETRY),
)
async def delete(self, url: str) -> httpx.Response:
oauth_session = await self._oauth2_session()
client = await self._http_client()

r = await oauth_session.delete(
r = await client.delete(
f"{self._api_base_url}/{url}",
headers={
"Authorization": oauth_session.token["access_token"],
"accept": "application/json",
},
)
r.raise_for_status()

Expand Down
2 changes: 1 addition & 1 deletion src/sfapi_client/_async/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

def check_auth(method: Callable):
def wrapper(self, *args, **kwargs):
if self.client._client_id is None:
if self.client._client_id is None and self.client._access_token is None:
raise SfApiError(
f"Cannot call {self.__class__.__name__}.{method.__name__}() with an unauthenticated client." # noqa: E501
)
Expand Down
Loading

0 comments on commit 7fb4175

Please sign in to comment.