Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CLI tests #104

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion src/diracx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from importlib.metadata import PackageNotFoundError, version

logging.basicConfig(
level=logging.DEBUG, format="%(asctime)s | %(name)s | %(levelname)s | %(message)s"
level=logging.WARNING, format="%(asctime)s | %(name)s | %(levelname)s | %(message)s"
)

try:
Expand Down
5 changes: 5 additions & 0 deletions src/diracx/api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from __future__ import annotations

__all__ = ("jobs",)

from . import jobs
83 changes: 83 additions & 0 deletions src/diracx/api/jobs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from __future__ import annotations

__all__ = ("create_sandbox", "download_sandbox")

import hashlib
import logging
import os
import tarfile
import tempfile
from pathlib import Path

import httpx

from diracx.client.aio import DiracClient
from diracx.client.models import SandboxInfo

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Name of the hashlib algorithm used to checksum the sandbox archive; the same
# string is reported to the server as ``checksum_algorithm``.
SANDBOX_CHECKSUM_ALGORITHM = "sha256"
# Compression used for the sandbox tarball; it is interpolated both into the
# tarfile open mode ("w|<compression>") and the reported format ("tar.<compression>").
SANDBOX_COMPRESSION = "bz2"


async def create_sandbox(client: DiracClient, paths: list[Path]) -> str:
    """Create a sandbox from the given paths and upload it to the storage backend.

    Any paths that are directories will be added recursively.
    The returned value is the PFN of the sandbox in the storage backend and can
    be used to submit jobs.

    :param client: DiracX client used to initiate the sandbox upload.
    :param paths: Files and/or directories to put in the sandbox. Each entry is
        stored in the archive under its base name (``path.name``).
    :returns: The PFN reported by the server for this sandbox.
    :raises httpx.HTTPStatusError: If the upload request to the storage backend
        returns an error status.
    """
    with tempfile.TemporaryFile(mode="w+b") as tar_fh:
        # Build the compressed tarball in a temporary file so that it can be
        # checksummed and then uploaded without holding it all in memory.
        with tarfile.open(fileobj=tar_fh, mode=f"w|{SANDBOX_COMPRESSION}") as tf:
            for path in paths:
                resolved = path.resolve()
                logger.debug("Adding %s to sandbox as %s", resolved, path.name)
                tf.add(resolved, path.name, recursive=True)
        # The archive is only complete once the tarfile has been closed.
        tar_fh.seek(0)

        hasher = hashlib.new(SANDBOX_CHECKSUM_ALGORITHM)
        while data := tar_fh.read(512 * 1024):
            hasher.update(data)
        checksum = hasher.hexdigest()
        # Rewind again so the upload below sends the archive from the start.
        tar_fh.seek(0)
        logger.debug("Sandbox checksum is %s", checksum)

        sandbox_info = SandboxInfo(
            checksum_algorithm=SANDBOX_CHECKSUM_ALGORITHM,
            checksum=checksum,
            size=os.stat(tar_fh.fileno()).st_size,
            format=f"tar.{SANDBOX_COMPRESSION}",
        )

        res = await client.jobs.initiate_sandbox_upload(sandbox_info)
        if res.url:
            logger.debug("Uploading sandbox for %s", res.pfn)
            files = {"file": ("file", tar_fh)}
            # Use the async client so the upload does not block the event loop
            # (the synchronous ``httpx.post`` would stall every other task).
            async with httpx.AsyncClient() as http_client:
                response = await http_client.post(
                    res.url, data=res.fields, files=files
                )
                # TODO: Handle this error better
                response.raise_for_status()
            logger.debug(
                "Sandbox uploaded for %s with status code %s",
                res.pfn,
                response.status_code,
            )
        else:
            # No upload URL was returned: nothing to upload.
            logger.debug("%s already exists in storage backend", res.pfn)
        return res.pfn


async def download_sandbox(client: DiracClient, pfn: str, destination: Path) -> None:
    """Download a sandbox from the storage backend to the given destination.

    :param client: DiracX client used to look up the download URL for ``pfn``.
    :param pfn: PFN of the sandbox to download.
    :param destination: Directory into which the sandbox archive is extracted.
    :raises httpx.HTTPStatusError: If the download request returns an error
        status.
    """
    res = await client.jobs.get_sandbox_file(pfn)
    logger.debug("Downloading sandbox for %s", pfn)
    with tempfile.TemporaryFile(mode="w+b") as fh:
        async with httpx.AsyncClient() as http_client:
            response = await http_client.get(res.url)
            # TODO: Handle this error better
            response.raise_for_status()
            async for chunk in response.aiter_bytes():
                fh.write(chunk)
        # After writing, the file position is at EOF; tarfile expects the file
        # object to be positioned at the start of the archive.
        fh.seek(0)
        logger.debug("Sandbox downloaded for %s", pfn)

        with tarfile.open(fileobj=fh) as tf:
            # filter="data" rejects dangerous archive members (absolute paths,
            # links pointing outside the destination, device files, ...).
            tf.extractall(path=destination, filter="data")
        logger.debug("Extracted %s to %s", pfn, destination)
69 changes: 46 additions & 23 deletions src/diracx/client/_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from datetime import datetime
import json
import requests
import logging

from pathlib import Path
from typing import Any, Dict, List, Optional, cast
Expand Down Expand Up @@ -38,6 +39,9 @@ def patch_sdk():
"""


logger = logging.getLogger(__name__)


class DiracTokenCredential(TokenCredential):
"""Tailor get_token() for our context"""

Expand All @@ -52,7 +56,7 @@ def get_token(
claims: Optional[str] = None,
tenant_id: Optional[str] = None,
**kwargs: Any,
) -> AccessToken:
) -> AccessToken | None:
"""Refresh the access token using the refresh_token flow.
:param str scopes: The type of access needed.
:keyword str claims: Additional claims required in the token, such as those returned in a resource
Expand Down Expand Up @@ -98,12 +102,21 @@ def on_request(
return

if not self._token:
credentials = json.loads(self._credential.location.read_text())
self._token = self._credential.get_token(
"", refresh_token=credentials["refresh_token"]
)

request.http_request.headers["Authorization"] = f"Bearer {self._token.token}"
try:
credentials = json.loads(self._credential.location.read_text())
except Exception:
logger.warning(
"Cannot load credentials from %s", self._credential.location
)
else:
self._token = self._credential.get_token(
"", refresh_token=credentials["refresh_token"]
)

if self._token:
request.http_request.headers[
"Authorization"
] = f"Bearer {self._token.token}"


class DiracClient(DiracGenerated):
Expand Down Expand Up @@ -146,7 +159,7 @@ def __aenter__(self) -> "DiracClient":

def refresh_token(
location: Path, token_endpoint: str, client_id: str, refresh_token: str
) -> AccessToken:
) -> AccessToken | None:
"""Refresh the access token using the refresh_token flow."""
from diracx.core.utils import write_credentials

Expand All @@ -159,7 +172,13 @@ def refresh_token(
},
)

if response.status_code != 200:
if response.status_code == 401:
reason = response.json()["detail"]
logger.warning("Your refresh token is not valid anymore: %s", reason)
location.unlink()
return None
elif response.status_code != 200:
# TODO: Better handle this case, retry?
raise RuntimeError(
f"An issue occured while refreshing your access token: {response.json()['detail']}"
)
Expand Down Expand Up @@ -192,24 +211,28 @@ def get_token(location: Path, token: AccessToken | None) -> AccessToken | None:
raise RuntimeError("credentials are not set")

# Load the existing credentials
if not token:
credentials = json.loads(location.read_text())
token = AccessToken(
cast(str, credentials.get("access_token")),
cast(int, credentials.get("expires_on")),
)

# We check the validity of the token
# If not valid, then return None to inform the caller that a new token
# is needed
if not is_token_valid(token):
return None

return token
try:
if not token:
credentials = json.loads(location.read_text())
token = AccessToken(
cast(str, credentials.get("access_token")),
cast(int, credentials.get("expires_on")),
)
except Exception:
logger.warning("Cannot load credentials from %s", location)
pass
else:
# We check the validity of the token
# If not valid, then return None to inform the caller that a new token
# is needed
if is_token_valid(token):
return token
return None


def is_token_valid(token: AccessToken) -> bool:
    """Return whether *token* can still be used without being refreshed.

    The token is considered valid only if it expires more than 300 seconds
    from now; a ``False`` result tells the caller that a new token is needed.

    :param token: Token whose ``expires_on`` attribute is a POSIX timestamp.
    :returns: ``True`` if more than 300 seconds remain before expiry.
    """
    # Imported locally so this function does not depend on a change to the
    # module-level imports.
    from datetime import timezone

    # TODO: Should we check against the userinfo endpoint?
    # Timezone-aware datetimes replace the naive ``utcfromtimestamp`` /
    # ``utcnow`` helpers, which are deprecated since Python 3.12.
    expires_at = datetime.fromtimestamp(token.expires_on, tz=timezone.utc)
    remaining = expires_at - datetime.now(tz=timezone.utc)
    return remaining.total_seconds() > 300
27 changes: 20 additions & 7 deletions src/diracx/client/aio/_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize
"""
import json
import logging
from types import TracebackType
from pathlib import Path
from typing import Any, List, Optional
Expand All @@ -24,6 +25,8 @@
"DiracClient",
] # Add all objects you want publicly available to users at this package level

logger = logging.getLogger(__name__)


def patch_sdk():
"""Do not remove from this file.
Expand All @@ -48,7 +51,7 @@ async def get_token(
claims: Optional[str] = None,
tenant_id: Optional[str] = None,
**kwargs: Any,
) -> AccessToken:
) -> AccessToken | None:
"""Refresh the access token using the refresh_token flow.
:param str scopes: The type of access needed.
:keyword str claims: Additional claims required in the token, such as those returned in a resource
Expand Down Expand Up @@ -104,19 +107,29 @@ async def on_request(
credentials: dict[str, Any]

try:
# TODO: Use httpx and await this call
self._token = get_token(self._credential.location, self._token)
except RuntimeError:
# If we are here, it means the credentials path does not exist
# we suppose it is not needed to perform the request
return

if not self._token:
credentials = json.loads(self._credential.location.read_text())
self._token = await self._credential.get_token(
"", refresh_token=credentials["refresh_token"]
)

request.http_request.headers["Authorization"] = f"Bearer {self._token.token}"
try:
credentials = json.loads(self._credential.location.read_text())
except Exception:
logger.warning(
"Cannot load credentials from %s", self._credential.location
)
else:
self._token = await self._credential.get_token(
"", refresh_token=credentials["refresh_token"]
)

if self._token:
request.http_request.headers[
"Authorization"
] = f"Bearer {self._token.token}"


class DiracClient(DiracGenerated):
Expand Down
Loading