diff --git a/.github/workflows/extensions.yml b/.github/workflows/extensions.yml index 19eb122d..5645c2ef 100644 --- a/.github/workflows/extensions.yml +++ b/.github/workflows/extensions.yml @@ -27,11 +27,13 @@ jobs: - package: "./extensions/gubbins/gubbins-core" dependencies: "./extensions/gubbins/gubbins-testing ./diracx-testing ./diracx-core" - package: "./extensions/gubbins/gubbins-db" - dependencies: "./extensions/gubbins/gubbins-testing ./extensions/gubbins/gubbins-core ./diracx-testing ./diracx-db ./diracx-core " + dependencies: "./extensions/gubbins/gubbins-testing ./extensions/gubbins/gubbins-core ./diracx-testing ./diracx-db ./diracx-core" + - package: "./extensions/gubbins/gubbins-logic" + dependencies: "./extensions/gubbins/gubbins-db ./extensions/gubbins/gubbins-core ./diracx-db ./diracx-core ./diracx-logic" - package: "./extensions/gubbins/gubbins-routers" - dependencies: "./extensions/gubbins/gubbins-testing ./extensions/gubbins/gubbins-db ./extensions/gubbins/gubbins-core ./diracx-testing ./diracx-db ./diracx-core ./diracx-routers" + dependencies: "./extensions/gubbins/gubbins-testing ./extensions/gubbins/gubbins-db ./extensions/gubbins/gubbins-logic ./extensions/gubbins/gubbins-core ./diracx-testing ./diracx-db ./diracx-logic ./diracx-core ./diracx-routers" - package: "./extensions/gubbins/gubbins-client" - dependencies: "./extensions/gubbins/gubbins-testing ./diracx-testing ./extensions/gubbins/gubbins-client ./extensions/gubbins/gubbins-core ./diracx-client ./diracx-core " + dependencies: "./extensions/gubbins/gubbins-testing ./diracx-testing ./extensions/gubbins/gubbins-client ./extensions/gubbins/gubbins-core ./diracx-client ./diracx-core" - package: "./extensions/gubbins/gubbins-cli" dependencies: "./extensions/gubbins/gubbins-testing ./extensions/gubbins/gubbins-client ./extensions/gubbins/gubbins-core ./diracx-testing ./diracx-cli ./diracx-client ./diracx-core ./diracx-api" steps: @@ -58,6 +60,7 @@ jobs: run: | mypy ${{ matrix.package }}/src - name: Run pytest + if: ${{ matrix.package != './extensions/gubbins/gubbins-logic' }} run: | cd ${{ matrix.package }} pip install .[testing] @@ -147,7 +150,7 @@ jobs: outputs: type=docker,dest=/tmp/gubbins_services_image.tar build-args: | EXTRA_PACKAGES_TO_INSTALL=git+https://github.com/DIRACGrid/DIRAC.git@integration - EXTENSION_CUSTOM_SOURCES_TO_INSTALL=/bindmount/gubbins_db*.whl,/bindmount/gubbins_routers*.whl,/bindmount/gubbins_client*.whl + EXTENSION_CUSTOM_SOURCES_TO_INSTALL=/bindmount/gubbins_db*.whl,/bindmount/gubbins_logic*.whl,/bindmount/gubbins_routers*.whl,/bindmount/gubbins_client*.whl - name: Build and export client uses: docker/build-push-action@v6 with: @@ -185,7 +188,7 @@ jobs: run: | pip install pytest-github-actions-annotate-failures pip install git+https://github.com/DIRACGrid/DIRAC.git@integration - pip install ./diracx-core/[testing] ./diracx-api/[testing] ./diracx-cli/[testing] ./diracx-client/[testing] ./diracx-routers/[testing] ./diracx-db/[testing] ./diracx-testing/[testing] ./extensions/gubbins/gubbins-testing[testing] ./extensions/gubbins/gubbins-db[testing] ./extensions/gubbins/gubbins-routers/[testing] ./extensions/gubbins/gubbins-client/[testing] ./extensions/gubbins/gubbins-cli/[testing] ./extensions/gubbins/gubbins-core/[testing] + pip install ./diracx-core/[testing] ./diracx-api/[testing] ./diracx-cli/[testing] ./diracx-client/[testing] ./diracx-routers/[testing] ./diracx-logic/[testing] ./diracx-db/[testing] ./diracx-testing/[testing] ./extensions/gubbins/gubbins-testing[testing] 
./extensions/gubbins/gubbins-db[testing] ./extensions/gubbins/gubbins-logic/[testing] ./extensions/gubbins/gubbins-routers/[testing] ./extensions/gubbins/gubbins-client/[testing] ./extensions/gubbins/gubbins-cli/[testing] ./extensions/gubbins/gubbins-core/[testing] - name: Start demo run: | git clone https://github.com/DIRACGrid/diracx-charts.git ../diracx-charts @@ -274,7 +277,7 @@ jobs: run: | micromamba install -c conda-forge nodejs pre-commit pip install git+https://github.com/DIRACGrid/DIRAC.git@integration - pip install ./diracx-core/[testing] ./diracx-api/[testing] ./diracx-cli/[testing] ./diracx-client/[testing] ./diracx-routers/[testing] ./diracx-db/[testing] ./diracx-testing/[testing] ./extensions/gubbins/gubbins-testing[testing] ./extensions/gubbins/gubbins-db[testing] ./extensions/gubbins/gubbins-routers/[testing] ./extensions/gubbins/gubbins-testing/[testing] -e ./extensions/gubbins/gubbins-client/[testing] ./extensions/gubbins/gubbins-core/[testing] + pip install ./diracx-core/[testing] ./diracx-api/[testing] ./diracx-cli/[testing] ./diracx-client/[testing] ./diracx-routers/[testing] ./diracx-logic/[testing] ./diracx-db/[testing] ./diracx-testing/[testing] ./extensions/gubbins/gubbins-testing[testing] ./extensions/gubbins/gubbins-db[testing] ./extensions/gubbins/gubbins-logic[testing] ./extensions/gubbins/gubbins-routers/[testing] ./extensions/gubbins/gubbins-testing/[testing] -e ./extensions/gubbins/gubbins-client/[testing] ./extensions/gubbins/gubbins-core/[testing] npm install -g autorest - name: Run autorest run: | diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 32d2f399..3942e7e3 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -29,7 +29,6 @@ jobs: run: | find -name '*.sh' -print0 | xargs -0 -n1 shellcheck --exclude=SC1090,SC1091 --external-source - unittest: name: Unit test - ${{ matrix.package }} runs-on: ubuntu-latest @@ -41,8 +40,10 @@ jobs: dependencies: "./diracx-testing" - package: "./diracx-db" dependencies: "./diracx-testing ./diracx-core" + - package: "./diracx-logic" + dependencies: "./diracx-core ./diracx-db" - package: "./diracx-routers" - dependencies: "./diracx-testing ./diracx-core ./diracx-db" + dependencies: "./diracx-testing ./diracx-core ./diracx-db ./diracx-logic" - package: "./diracx-client" dependencies: "./diracx-testing ./diracx-core" - package: "./diracx-api" @@ -68,6 +69,7 @@ jobs: pip install git+https://github.com/DIRACGrid/DIRAC.git@integration pip install ${{ matrix.dependencies }} - name: Run pytest + if: ${{ matrix.package != './diracx-logic' }} run: | cd ${{ matrix.package }} pip install .[testing] @@ -93,7 +95,7 @@ jobs: run: | pip install pytest-github-actions-annotate-failures pip install git+https://github.com/DIRACGrid/DIRAC.git@integration - pip install ./diracx-core/[testing] ./diracx-api/[testing] ./diracx-cli/[testing] ./diracx-client/[testing] ./diracx-routers/[testing] ./diracx-db/[testing] ./diracx-testing/ + pip install ./diracx-core/[testing] ./diracx-api/[testing] ./diracx-cli/[testing] ./diracx-client/[testing] ./diracx-routers/[testing] ./diracx-logic/[testing] ./diracx-db/[testing] ./diracx-testing/ - name: Start demo run: | git clone https://github.com/DIRACGrid/diracx-charts.git ../diracx-charts @@ -162,7 +164,7 @@ jobs: run: | micromamba install -c conda-forge nodejs pre-commit pip install git+https://github.com/DIRACGrid/DIRAC.git@integration - pip install ./diracx-core/ ./diracx-api/ ./diracx-cli/ -e ./diracx-client/[testing]
./diracx-routers/[testing] ./diracx-db/ ./diracx-testing/ + pip install ./diracx-core/ ./diracx-api/ ./diracx-cli/ -e ./diracx-client/[testing] ./diracx-routers/[testing] ./diracx-logic/[testing] ./diracx-db/ ./diracx-testing/ npm install -g autorest - name: Run autorest run: | diff --git a/diracx-cli/src/diracx/cli/internal/legacy.py b/diracx-cli/src/diracx/cli/internal/legacy.py index a5e07acb..dbed7dbb 100644 --- a/diracx-cli/src/diracx/cli/internal/legacy.py +++ b/diracx-cli/src/diracx/cli/internal/legacy.py @@ -230,6 +230,7 @@ def generate_helm_values( helm_values["diracx"] = diracx_config diracx_config["hostname"] = diracx_hostname + diracx_settings["DIRACX_SERVICE_AUTH_TOKEN_ISSUER"] = diracx_url diracx_settings["DIRACX_SERVICE_AUTH_ALLOWED_REDIRECTS"] = json.dumps( [ urljoin(diracx_url, "api/docs/oauth2-redirect"), diff --git a/diracx-core/pyproject.toml b/diracx-core/pyproject.toml index 5b058978..4955e2ac 100644 --- a/diracx-core/pyproject.toml +++ b/diracx-core/pyproject.toml @@ -33,6 +33,7 @@ testing = [ ] types = [ "botocore-stubs", + "types-aiobotocore[essential]", "types-aiobotocore-s3", "types-cachetools", "types-PyYAML", diff --git a/diracx-core/src/diracx/core/exceptions.py b/diracx-core/src/diracx/core/exceptions.py index 04b70192..bea042e7 100644 --- a/diracx-core/src/diracx/core/exceptions.py +++ b/diracx-core/src/diracx/core/exceptions.py @@ -28,6 +28,18 @@ class ExpiredFlowError(AuthorizationError): """Used only for the Device Flow when the polling is expired.""" +class IAMServerError(DiracError): + """Used whenever we encounter a server problem with the IAM server.""" + + +class IAMClientError(DiracError): + """Used whenever we encounter a client problem with the IAM server.""" + + +class InvalidCredentialsError(DiracError): + """Used whenever the credentials are invalid.""" + + class ConfigurationError(DiracError): """Used whenever we encounter a problem with the configuration.""" @@ -40,6 +52,12 @@ class InvalidQueryError(DiracError): """It was not possible to build a valid database query from the given input.""" +class TokenNotFoundError(Exception): + def __init__(self, jti: str, detail: str | None = None): + self.jti: str = jti + super().__init__(f"Token {jti} not found" + (" ({detail})" if detail else "")) + + class JobNotFoundError(Exception): def __init__(self, job_id: int, detail: str | None = None): self.job_id: int = job_id @@ -66,6 +84,16 @@ def __init__(self, pfn: str, se_name: str, detail: str | None = None): ) +class SandboxAlreadyInsertedError(Exception): + def __init__(self, pfn: str, se_name: str, detail: str | None = None): + self.pfn: str = pfn + self.se_name: str = se_name + super().__init__( + f"Sandbox with {pfn} and {se_name} already inserted" + + (" ({detail})" if detail else "") + ) + + class JobError(Exception): def __init__(self, job_id, detail: str | None = None): self.job_id: int = job_id diff --git a/diracx-core/src/diracx/core/models.py b/diracx-core/src/diracx/core/models.py index ca5598d7..b65f66c4 100644 --- a/diracx-core/src/diracx/core/models.py +++ b/diracx-core/src/diracx/core/models.py @@ -1,3 +1,8 @@ +"""Models are used to define the data structure of the requests and responses +for the DiracX API. They are shared between the client components (cli, api) and +services components (db, logic, routers). 
+""" + from __future__ import annotations from datetime import datetime @@ -46,12 +51,25 @@ class SortSpec(TypedDict): direction: SortDirection -class TokenResponse(BaseModel): - # Based on RFC 6749 - access_token: str - expires_in: int - token_type: str = "Bearer" # noqa: S105 - refresh_token: str | None = None +class InsertedJob(TypedDict): + JobID: int + Status: str + MinorStatus: str + TimeStamp: datetime + + +class JobSummaryParams(BaseModel): + grouping: list[str] + search: list[SearchSpec] = [] + # TODO: Add more validation + + +class JobSearchParams(BaseModel): + parameters: list[str] | None = None + search: list[SearchSpec] = [] + sort: list[SortSpec] = [] + distinct: bool = False + # TODO: Add more validation class JobStatus(StrEnum): @@ -77,6 +95,15 @@ class JobMinorStatus(StrEnum): RESCHEDULED = "Job Rescheduled" +class JobLoggingRecord(BaseModel): + job_id: int + status: JobStatus + minor_status: str + application_status: str + date: datetime + source: str + + class JobStatusUpdate(BaseModel): Status: JobStatus | None = None MinorStatus: str | None = None @@ -136,3 +163,93 @@ class SandboxInfo(BaseModel): class SandboxType(StrEnum): Input = "Input" Output = "Output" + + +class SandboxDownloadResponse(BaseModel): + url: str + expires_in: int + + +class SandboxUploadResponse(BaseModel): + pfn: str + url: str | None = None + fields: dict[str, str] = {} + + +class GrantType(StrEnum): + """Grant types for OAuth2.""" + + authorization_code = "authorization_code" + device_code = "urn:ietf:params:oauth:grant-type:device_code" + refresh_token = "refresh_token" # noqa: S105 # False positive of Bandit about hard coded password + + +class InitiateDeviceFlowResponse(TypedDict): + """Response for the device flow initiation.""" + + user_code: str + device_code: str + verification_uri_complete: str + verification_uri: str + expires_in: int + + +class OpenIDConfiguration(TypedDict): + issuer: str + token_endpoint: str + userinfo_endpoint: str + authorization_endpoint: str + device_authorization_endpoint: str + grant_types_supported: list[str] + scopes_supported: list[str] + response_types_supported: list[str] + token_endpoint_auth_signing_alg_values_supported: list[str] + token_endpoint_auth_methods_supported: list[str] + code_challenge_methods_supported: list[str] + + +class TokenPayload(TypedDict): + jti: str + exp: datetime + dirac_policies: dict + + +class TokenResponse(BaseModel): + # Based on RFC 6749 + access_token: str + expires_in: int + token_type: str = "Bearer" # noqa: S105 + refresh_token: str | None = None + + +class AccessTokenPayload(TokenPayload): + sub: str + vo: str + iss: str + dirac_properties: list[str] + preferred_username: str + dirac_group: str + + +class RefreshTokenPayload(TokenPayload): + legacy_exchange: bool + + +class SupportInfo(TypedDict): + message: str + webpage: str | None + email: str | None + + +class GroupInfo(TypedDict): + properties: list[str] + + +class VOInfo(TypedDict): + groups: dict[str, GroupInfo] + support: SupportInfo + default_group: str + + +class Metadata(TypedDict): + virtual_organizations: dict[str, VOInfo] diff --git a/diracx-core/src/diracx/core/settings.py b/diracx-core/src/diracx/core/settings.py index 819633f5..99059023 100644 --- a/diracx-core/src/diracx/core/settings.py +++ b/diracx-core/src/diracx/core/settings.py @@ -1,5 +1,8 @@ from __future__ import annotations +from diracx.core.properties import SecurityProperty +from diracx.core.s3 import s3_bucket_exists + __all__ = ( "SqlalchemyDsn", "LocalFileUrl", @@ -9,20 +12,28 @@ import 
contextlib from collections.abc import AsyncIterator from pathlib import Path -from typing import Annotated, Any, Self, TypeVar +from typing import TYPE_CHECKING, Annotated, Any, Self, TypeVar +from aiobotocore.session import get_session from authlib.jose import JsonWebKey +from botocore.config import Config +from botocore.errorfactory import ClientError from cryptography.fernet import Fernet from pydantic import ( AnyUrl, BeforeValidator, + Field, FileUrl, + PrivateAttr, SecretStr, TypeAdapter, UrlConstraints, ) from pydantic_settings import BaseSettings, SettingsConfigDict +if TYPE_CHECKING: + from types_aiobotocore_s3.client import S3Client + T = TypeVar("T") @@ -102,3 +113,73 @@ class DevelopmentSettings(ServiceSettingsBase): # When then to true (only for demo/CI), crash if an access policy isn't # called crash_on_missed_access_policy: bool = False + + @classmethod + def create(cls) -> Self: + return cls() + + +class AuthSettings(ServiceSettingsBase): + """Settings for the authentication service.""" + + model_config = SettingsConfigDict(env_prefix="DIRACX_SERVICE_AUTH_") + + dirac_client_id: str = "myDIRACClientID" + # TODO: This should be taken dynamically + # ["http://pclhcb211:8000/docs/oauth2-redirect"] + allowed_redirects: list[str] = [] + device_flow_expiration_seconds: int = 600 + authorization_flow_expiration_seconds: int = 300 + + # State key is used to encrypt/decrypt the state dict passed to the IAM + state_key: FernetKey + + token_issuer: str + token_key: TokenSigningKey + token_algorithm: str = "RS256" # noqa: S105 + access_token_expire_minutes: int = 20 + refresh_token_expire_minutes: int = 60 + + available_properties: set[SecurityProperty] = Field( + default_factory=SecurityProperty.available_properties + ) + + +class SandboxStoreSettings(ServiceSettingsBase): + """Settings for the sandbox store.""" + + model_config = SettingsConfigDict(env_prefix="DIRACX_SANDBOX_STORE_") + + bucket_name: str + s3_client_kwargs: dict[str, str] + auto_create_bucket: bool = False + url_validity_seconds: int = 5 * 60 + se_name: str = "SandboxSE" + _client: S3Client = PrivateAttr() + + @contextlib.asynccontextmanager + async def lifetime_function(self) -> AsyncIterator[None]: + async with get_session().create_client( + "s3", + **self.s3_client_kwargs, + config=Config(signature_version="v4"), + ) as self._client: # type: ignore + if not await s3_bucket_exists(self._client, self.bucket_name): + if not self.auto_create_bucket: + raise ValueError( + f"Bucket {self.bucket_name} does not exist and auto_create_bucket is disabled" + ) + try: + await self._client.create_bucket(Bucket=self.bucket_name) + except ClientError as e: + raise ValueError( + f"Failed to create bucket {self.bucket_name}" + ) from e + + yield + + @property + def s3_client(self) -> S3Client: + if self._client is None: + raise RuntimeError("S3 client accessed before lifetime function") + return self._client diff --git a/diracx-db/pyproject.toml b/diracx-db/pyproject.toml index fc4ec487..83936499 100644 --- a/diracx-db/pyproject.toml +++ b/diracx-db/pyproject.toml @@ -13,9 +13,7 @@ classifiers = [ "Topic :: System :: Distributed Computing", ] dependencies = [ - "dirac", "diracx-core", - "fastapi", "opensearch-py[async]", "pydantic >=2.10", "sqlalchemy[aiomysql,aiosqlite] >= 2", @@ -27,7 +25,7 @@ testing = [ "diracx-testing", ] -[project.entry-points."diracx.db.sql"] +[project.entry-points."diracx.dbs.sql"] AuthDB = "diracx.db.sql:AuthDB" JobDB = "diracx.db.sql:JobDB" JobLoggingDB = "diracx.db.sql:JobLoggingDB" @@ -35,7 +33,7 @@ 
PilotAgentsDB = "diracx.db.sql:PilotAgentsDB" SandboxMetadataDB = "diracx.db.sql:SandboxMetadataDB" TaskQueueDB = "diracx.db.sql:TaskQueueDB" -[project.entry-points."diracx.db.os"] +[project.entry-points."diracx.dbs.os"] JobParametersDB = "diracx.db.os:JobParametersDB" [tool.setuptools.packages.find] diff --git a/diracx-db/src/diracx/db/os/utils.py b/diracx-db/src/diracx/db/os/utils.py index 431cceaa..2b345058 100644 --- a/diracx-db/src/diracx/db/os/utils.py +++ b/diracx-db/src/diracx/db/os/utils.py @@ -38,7 +38,7 @@ class BaseOSDB(metaclass=ABCMeta): The available OpenSearch databases are discovered by calling `BaseOSDB.available_urls`. This method returns a dictionary of database names to connection parameters. - The available databases are determined by the `diracx.db.os` entrypoint in + The available databases are determined by the `diracx.dbs.os` entrypoint in the `pyproject.toml` file and the connection parameters are taken from the environment variables prefixed with `DIRACX_OS_DB_{DB_NAME}`. @@ -92,7 +92,9 @@ def available_implementations(cls, db_name: str) -> list[type[BaseOSDB]]: """Return the available implementations of the DB in reverse priority order.""" db_classes: list[type[BaseOSDB]] = [ entry_point.load() - for entry_point in select_from_extension(group="diracx.db.os", name=db_name) + for entry_point in select_from_extension( + group="diracx.dbs.os", name=db_name + ) ] if not db_classes: raise NotImplementedError(f"Could not find any matches for {db_name=}") @@ -106,7 +108,7 @@ def available_urls(cls) -> dict[str, dict[str, Any]]: prefixed with ``DIRACX_OS_DB_{DB_NAME}``. """ conn_kwargs: dict[str, dict[str, Any]] = {} - for entry_point in select_from_extension(group="diracx.db.os"): + for entry_point in select_from_extension(group="diracx.dbs.os"): db_name = entry_point.name var_name = f"DIRACX_OS_DB_{entry_point.name.upper()}" if var_name in os.environ: diff --git a/diracx-db/src/diracx/db/sql/auth/db.py b/diracx-db/src/diracx/db/sql/auth/db.py index b587f869..55a7de36 100644 --- a/diracx-db/src/diracx/db/sql/auth/db.py +++ b/diracx-db/src/diracx/db/sql/auth/db.py @@ -1,19 +1,16 @@ from __future__ import annotations -import hashlib import secrets -from datetime import datetime -from uuid import uuid4 +from uuid import UUID, uuid4 from sqlalchemy import insert, select, update from sqlalchemy.exc import IntegrityError, NoResultFound from diracx.core.exceptions import ( AuthorizationError, - ExpiredFlowError, - PendingAuthorizationError, + TokenNotFoundError, ) -from diracx.db.sql.utils import BaseSQLDB, substract_date +from diracx.db.sql.utils import BaseSQLDB, hash, substract_date from .schema import ( AuthorizationFlows, @@ -50,44 +47,25 @@ async def device_flow_validate_user_code( return (await self.conn.execute(stmt)).scalar_one() - async def get_device_flow(self, device_code: str, max_validity: int): + async def get_device_flow(self, device_code: str): """:raises: NoResultFound""" # The with_for_update # prevents that the token is retrieved # multiple time concurrently - stmt = select( - DeviceFlows, - (DeviceFlows.creation_time < substract_date(seconds=max_validity)).label( - "IsExpired" - ), - ).with_for_update() + stmt = select(DeviceFlows).with_for_update() stmt = stmt.where( - DeviceFlows.device_code == hashlib.sha256(device_code.encode()).hexdigest(), + DeviceFlows.device_code == hash(device_code), ) - res = dict((await self.conn.execute(stmt)).one()._mapping) - - if res["IsExpired"]: - raise ExpiredFlowError() - - if res["Status"] == FlowStatus.READY: - # 
Update the status to Done before returning - await self.conn.execute( - update(DeviceFlows) - .where( - DeviceFlows.device_code - == hashlib.sha256(device_code.encode()).hexdigest() - ) - .values(status=FlowStatus.DONE) - ) - return res - - if res["Status"] == FlowStatus.DONE: - raise AuthorizationError("Code was already used") - - if res["Status"] == FlowStatus.PENDING: - raise PendingAuthorizationError() + return dict((await self.conn.execute(stmt)).one()._mapping) - raise AuthorizationError("Bad state in device flow") + async def update_device_flow_status( + self, device_code: str, status: FlowStatus + ) -> None: + stmt = update(DeviceFlows).where( + DeviceFlows.device_code == hash(device_code), + ) + stmt = stmt.values(status=status) + await self.conn.execute(stmt) async def device_flow_insert_id_token( self, user_code: str, id_token: dict[str, str], max_validity: int @@ -121,7 +99,7 @@ async def insert_device_flow( device_code = secrets.token_urlsafe() # Hash the the device_code to avoid leaking information - hashed_device_code = hashlib.sha256(device_code.encode()).hexdigest() + hashed_device_code = hash(device_code) stmt = insert(DeviceFlows).values( client_id=client_id, @@ -171,7 +149,7 @@ async def authorization_flow_insert_id_token( """ # Hash the code to avoid leaking information code = secrets.token_urlsafe() - hashed_code = hashlib.sha256(code.encode()).hexdigest() + hashed_code = hash(code) stmt = update(AuthorizationFlows) @@ -193,7 +171,8 @@ async def authorization_flow_insert_id_token( return code, row.RedirectURI async def get_authorization_flow(self, code: str, max_validity: int): - hashed_code = hashlib.sha256(code.encode()).hexdigest() + """Get the authorization flow details based on the code.""" + hashed_code = hash(code) # The with_for_update # prevents that the token is retrieved # multiple time concurrently @@ -203,54 +182,41 @@ async def get_authorization_flow(self, code: str, max_validity: int): AuthorizationFlows.creation_time > substract_date(seconds=max_validity), ) - res = dict((await self.conn.execute(stmt)).one()._mapping) - - if res["Status"] == FlowStatus.READY: - # Update the status to Done before returning - await self.conn.execute( - update(AuthorizationFlows) - .where(AuthorizationFlows.code == hashed_code) - .values(status=FlowStatus.DONE) - ) - - return res + return dict((await self.conn.execute(stmt)).one()._mapping) - if res["Status"] == FlowStatus.DONE: - raise AuthorizationError("Code was already used") - - raise AuthorizationError("Bad state in authorization flow") + async def update_authorization_flow_status( + self, code: str, status: FlowStatus + ) -> None: + """Update the status of an authorization flow based on the code.""" + hashed_code = hash(code) + await self.conn.execute( + update(AuthorizationFlows) + .where(AuthorizationFlows.code == hashed_code) + .values(status=status) + ) async def insert_refresh_token( self, + jti: UUID, subject: str, preferred_username: str, scope: str, - ) -> tuple[str, datetime]: + ) -> None: """Insert a refresh token in the DB as well as user attributes required to generate access tokens. 
""" - # Generate a JWT ID - jti = str(uuid4()) - # Insert values into the DB stmt = insert(RefreshTokens).values( - jti=jti, + jti=str(jti), sub=subject, preferred_username=preferred_username, scope=scope, ) await self.conn.execute(stmt) - # Get the creation time of the new tuple: generated by the insert operation - stmt = select(RefreshTokens.creation_time) - stmt = stmt.where(RefreshTokens.jti == jti) - row = (await self.conn.execute(stmt)).one() - - # Return the JWT ID and the creation time - return jti, row.CreationTime - - async def get_refresh_token(self, jti: str) -> dict: + async def get_refresh_token(self, jti: UUID) -> dict: """Get refresh token details bound to a given JWT ID.""" + jti = str(jti) # The with_for_update # prevents that the token is retrieved # multiple time concurrently @@ -260,8 +226,8 @@ async def get_refresh_token(self, jti: str) -> dict: ) try: res = dict((await self.conn.execute(stmt)).one()._mapping) - except NoResultFound: - return {} + except NoResultFound as e: + raise TokenNotFoundError(jti) from e return res @@ -285,11 +251,11 @@ async def get_user_refresh_tokens(self, subject: str | None = None) -> list[dict return refresh_tokens - async def revoke_refresh_token(self, jti: str): + async def revoke_refresh_token(self, jti: UUID): """Revoke a token given by its JWT ID.""" await self.conn.execute( update(RefreshTokens) - .where(RefreshTokens.jti == jti) + .where(RefreshTokens.jti == str(jti)) .values(status=RefreshTokenStatus.REVOKED) ) diff --git a/diracx-db/src/diracx/db/sql/job/db.py b/diracx-db/src/diracx/db/sql/job/db.py index 7b918157..0f6e1b3f 100644 --- a/diracx-db/src/diracx/db/sql/job/db.py +++ b/diracx-db/src/diracx/db/sql/job/db.py @@ -4,14 +4,12 @@ from typing import TYPE_CHECKING, Any from sqlalchemy import bindparam, case, delete, func, insert, select, update -from sqlalchemy.exc import IntegrityError, NoResultFound if TYPE_CHECKING: from sqlalchemy.sql.elements import BindParameter -from diracx.core.exceptions import InvalidQueryError, JobNotFoundError +from diracx.core.exceptions import InvalidQueryError from diracx.core.models import ( - LimitedJobStatusReturn, SearchSpec, SortSpec, ) @@ -46,6 +44,7 @@ class JobDB(BaseSQLDB): jdl_2_db_parameters = ["JobName", "JobType", "JobGroup"] async def summary(self, group_by, search) -> list[dict[str, str | int]]: + """Get a summary of the jobs.""" columns = _get_columns(Jobs.__table__, group_by) stmt = select(*columns, func.count(Jobs.job_id).label("count")) @@ -69,6 +68,7 @@ async def search( per_page: int = 100, page: int | None = None, ) -> tuple[int, list[dict[Any, Any]]]: + """Search for jobs in the database.""" # Find which columns to select columns = _get_columns(Jobs.__table__, parameters) @@ -98,7 +98,24 @@ async def search( dict(row._mapping) async for row in (await self.conn.stream(stmt)) ] + async def create_job(self, compressed_original_jdl: str): + """Used to insert a new job with original JDL. 
Returns inserted job id.""" + result = await self.conn.execute( + JobJDLs.__table__.insert().values( + JDL="", + JobRequirements="", + OriginalJDL=compressed_original_jdl, + ) + ) + return result.lastrowid + + async def delete_jobs(self, job_ids: list[int]): + """Delete jobs from the database.""" + stmt = delete(JobJDLs).where(JobJDLs.job_id.in_(job_ids)) + await self.conn.execute(stmt) + async def insert_input_data(self, lfns: dict[int, list[str]]): + """Insert input data for jobs.""" await self.conn.execute( InputData.__table__.insert(), [ @@ -111,27 +128,8 @@ async def insert_input_data(self, lfns: dict[int, list[str]]): ], ) - async def set_job_attributes(self, job_id, job_data): - """TODO: add myDate and force parameters.""" - if "Status" in job_data: - job_data = job_data | {"LastUpdateTime": datetime.now(tz=timezone.utc)} - stmt = update(Jobs).where(Jobs.job_id == job_id).values(job_data) - await self.conn.execute(stmt) - - async def create_job(self, original_jdl): - """Used to insert a new job with original JDL. Returns inserted job id.""" - from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL - - result = await self.conn.execute( - JobJDLs.__table__.insert().values( - JDL="", - JobRequirements="", - OriginalJDL=compressJDL(original_jdl), - ) - ) - return result.lastrowid - async def insert_job_attributes(self, jobs_to_update: dict[int, dict]): + """Insert the job attributes.""" await self.conn.execute( Jobs.__table__.insert(), [ @@ -145,8 +143,6 @@ async def insert_job_attributes(self, jobs_to_update: dict[int, dict]): async def update_job_jdls(self, jdls_to_update: dict[int, str]): """Used to update the JDL, typically just after inserting the original JDL, or rescheduling, for example.""" - from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL - await self.conn.execute( JobJDLs.__table__.update().where( JobJDLs.__table__.c.JobID == bindparam("b_JobID") @@ -154,67 +150,15 @@ async def update_job_jdls(self, jdls_to_update: dict[int, str]): [ { "b_JobID": job_id, - "JDL": compressJDL(jdl), + "JDL": compressed_jdl, } - for job_id, jdl in jdls_to_update.items() + for job_id, compressed_jdl in jdls_to_update.items() ], ) - async def check_and_prepare_job( - self, - job_id, - class_ad_job, - class_ad_req, - owner, - owner_group, - job_attrs, - vo, - ): - """Check Consistency of Submitted JDL and set some defaults - Prepare subJDL with Job Requirements. 
- """ - from DIRAC.Core.Utilities.DErrno import EWMSSUBM, cmpError - from DIRAC.Core.Utilities.ReturnValues import returnValueOrRaise - from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import ( - checkAndPrepareJob, - ) - - ret_val = checkAndPrepareJob( - job_id, - class_ad_job, - class_ad_req, - owner, - owner_group, - job_attrs, - vo, - ) - - if not ret_val["OK"]: - if cmpError(ret_val, EWMSSUBM): - await self.set_job_attributes(job_id, job_attrs) - - returnValueOrRaise(ret_val) - - async def set_job_jdl(self, job_id, jdl): - from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL - - stmt = ( - update(JobJDLs).where(JobJDLs.job_id == job_id).values(JDL=compressJDL(jdl)) - ) - await self.conn.execute(stmt) - - async def set_job_jdl_bulk(self, jdls): - from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL - - await self.conn.execute( - JobJDLs.__table__.update().where( - JobJDLs.__table__.c.JobID == bindparam("b_JobID") - ), - [{"b_JobID": jid, "JDL": compressJDL(jdl)} for jid, jdl in jdls.items()], - ) - - async def set_job_attributes_bulk(self, job_data): - """TODO: add myDate and force parameters.""" + async def set_job_attributes(self, job_data): + """Update the parameters of the given jobs.""" + # TODO: add myDate and force parameters. for job_id in job_data.keys(): if "Status" in job_data[job_id]: job_data[job_id].update( @@ -240,11 +184,8 @@ async def set_job_attributes_bulk(self, job_data): ) await self.conn.execute(stmt) - async def get_job_jdls( - self, job_ids, original: bool = False - ) -> dict[int | str, str]: - from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import extractJDL - + async def get_job_jdls(self, job_ids, original: bool = False) -> dict[int, str]: + """Get the JDLs for the given jobs.""" if original: stmt = select(JobJDLs.job_id, JobJDLs.original_jdl).where( JobJDLs.job_id.in_(job_ids) @@ -254,37 +195,9 @@ async def get_job_jdls( JobJDLs.job_id.in_(job_ids) ) - return { - jobid: extractJDL(jdl) - for jobid, jdl in (await self.conn.execute(stmt)) - if jdl - } - - async def get_job_status(self, job_id: int) -> LimitedJobStatusReturn: - try: - stmt = select( - Jobs.status, Jobs.minor_status, Jobs.application_status - ).where(Jobs.job_id == job_id) - return LimitedJobStatusReturn( - **dict((await self.conn.execute(stmt)).one()._mapping) - ) - except NoResultFound as e: - raise JobNotFoundError(job_id) from e - - async def set_job_command(self, job_id: int, command: str, arguments: str = ""): - """Store a command to be passed to the job together with the next heart beat.""" - try: - stmt = insert(JobCommands).values( - JobID=job_id, - Command=command, - Arguments=arguments, - ReceptionTime=datetime.now(tz=timezone.utc), - ) - await self.conn.execute(stmt) - except IntegrityError as e: - raise JobNotFoundError(job_id) from e + return {jobid: jdl for jobid, jdl in (await self.conn.execute(stmt)) if jdl} - async def set_job_command_bulk(self, commands): + async def set_job_commands(self, commands: list[tuple[int, str, str]]): """Store a command to be passed to the job together with the next heart beat.""" await self.conn.execute( insert(JobCommands), @@ -298,12 +211,6 @@ async def set_job_command_bulk(self, commands): for job_id, command, arguments in commands ], ) - # FIXME handle IntegrityError - - async def delete_jobs(self, job_ids: list[int]): - """Delete jobs from the database.""" - stmt = delete(JobJDLs).where(JobJDLs.job_id.in_(job_ids)) - await self.conn.execute(stmt) async def set_properties( self, properties: dict[int, dict[str, 
Any]], update_timestamp: bool = False diff --git a/diracx-db/src/diracx/db/sql/job_logging/db.py b/diracx-db/src/diracx/db/sql/job_logging/db.py index 154671e0..00f1d053 100644 --- a/diracx-db/src/diracx/db/sql/job_logging/db.py +++ b/diracx-db/src/diracx/db/sql/job_logging/db.py @@ -1,20 +1,18 @@ from __future__ import annotations import time -from datetime import datetime, timezone +from datetime import timezone from typing import TYPE_CHECKING -from pydantic import BaseModel -from sqlalchemy import delete, func, insert, select +from sqlalchemy import delete, func, select if TYPE_CHECKING: pass from collections import defaultdict -from diracx.core.exceptions import JobNotFoundError from diracx.core.models import ( - JobStatus, + JobLoggingRecord, JobStatusReturn, ) @@ -27,61 +25,12 @@ MAGIC_EPOC_NUMBER = 1270000000 -class JobLoggingRecord(BaseModel): - job_id: int - status: JobStatus - minor_status: str - application_status: str - date: datetime - source: str - - class JobLoggingDB(BaseSQLDB): """Frontend for the JobLoggingDB. Provides the ability to store changes with timestamps.""" metadata = JobLoggingDBBase.metadata - async def insert_record( - self, - job_id: int, - status: JobStatus, - minor_status: str, - application_status: str, - date: datetime, - source: str, - ): - """Add a new entry to the JobLoggingDB table. One, two or all the three status - components (status, minorStatus, applicationStatus) can be specified. - Optionally the time stamp of the status can - be provided in a form of a string in a format '%Y-%m-%d %H:%M:%S' or - as datetime.datetime object. If the time stamp is not provided the current - UTC time is used. - """ - # First, fetch the maximum seq_num for the given job_id - seqnum_stmt = select(func.coalesce(func.max(LoggingInfo.seq_num) + 1, 1)).where( - LoggingInfo.job_id == job_id - ) - seqnum = await self.conn.scalar(seqnum_stmt) - - epoc = ( - time.mktime(date.timetuple()) - + date.microsecond / 1000000.0 - - MAGIC_EPOC_NUMBER - ) - - stmt = insert(LoggingInfo).values( - job_id=int(job_id), - seq_num=seqnum, - status=status, - minor_status=minor_status, - application_status=application_status[:255], - status_time=date, - status_time_order=epoc, - source=source[:32], - ) - await self.conn.execute(stmt) - - async def bulk_insert_record( + async def insert_records( self, records: list[JobLoggingRecord], ): @@ -103,15 +52,20 @@ def get_epoc(date): .group_by(LoggingInfo.job_id) ) - seqnum = {jid: seqnum for jid, seqnum in (await self.conn.execute(seqnum_stmt))} + seqnums = { + jid: seqnum for jid, seqnum in (await self.conn.execute(seqnum_stmt)) + } # IF a seqnum is not found, then assume it does not exist and the first sequence number is 1. 
# https://docs.sqlalchemy.org/en/20/orm/queryguide/dml.html#orm-bulk-insert-statements - await self.conn.execute( - LoggingInfo.__table__.insert(), - [ + values = [] + for record in records: + if record.job_id not in seqnums: + seqnums[record.job_id] = 1 + + values.append( { "JobID": record.job_id, - "SeqNum": seqnum.get(record.job_id, 1), + "SeqNum": seqnums[record.job_id], "Status": record.status, "MinorStatus": record.minor_status, "ApplicationStatus": record.application_status[:255], @@ -119,8 +73,12 @@ def get_epoc(date): "StatusTimeOrder": get_epoc(record.date), "StatusSource": record.source[:32], } - for record in records - ], + ) + seqnums[record.job_id] = seqnums[record.job_id] + 1 + + await self.conn.execute( + LoggingInfo.__table__.insert(), + values, ) async def get_records(self, job_ids: list[int]) -> dict[int, JobStatusReturn]: @@ -201,25 +159,7 @@ async def delete_records(self, job_ids: list[int]): stmt = delete(LoggingInfo).where(LoggingInfo.job_id.in_(job_ids)) await self.conn.execute(stmt) - async def get_wms_time_stamps(self, job_id): - """Get TimeStamps for job MajorState transitions - return a {State:timestamp} dictionary. - """ - result = {} - stmt = select( - LoggingInfo.status, - LoggingInfo.status_time_order, - ).where(LoggingInfo.job_id == job_id) - rows = await self.conn.execute(stmt) - if not rows.rowcount: - raise JobNotFoundError(job_id) from None - - for event, etime in rows: - result[event] = str(etime + MAGIC_EPOC_NUMBER) - - return result - - async def get_wms_time_stamps_bulk(self, job_ids): + async def get_wms_time_stamps(self, job_ids): """Get TimeStamps for job MajorState transitions for multiple jobs at once return a {JobID: {State:timestamp}} dictionary. """ diff --git a/diracx-db/src/diracx/db/sql/sandbox_metadata/db.py b/diracx-db/src/diracx/db/sql/sandbox_metadata/db.py index 68ed181a..6e106eef 100644 --- a/diracx-db/src/diracx/db/sql/sandbox_metadata/db.py +++ b/diracx-db/src/diracx/db/sql/sandbox_metadata/db.py @@ -5,9 +5,14 @@ from sqlalchemy import Executable, delete, insert, literal, select, update from sqlalchemy.exc import IntegrityError, NoResultFound -from diracx.core.exceptions import SandboxAlreadyAssignedError, SandboxNotFoundError +from diracx.core.exceptions import ( + SandboxAlreadyAssignedError, + SandboxAlreadyInsertedError, + SandboxNotFoundError, +) from diracx.core.models import SandboxInfo, SandboxType, UserInfo -from diracx.db.sql.utils import BaseSQLDB, utcnow +from diracx.db.sql.utils.base import BaseSQLDB +from diracx.db.sql.utils.functions import utcnow from .schema import Base as SandboxMetadataDBBase from .schema import SandBoxes, SBEntityMapping, SBOwners @@ -16,18 +21,16 @@ class SandboxMetadataDB(BaseSQLDB): metadata = SandboxMetadataDBBase.metadata - async def upsert_owner(self, user: UserInfo) -> int: + async def get_owner_id(self, user: UserInfo) -> int | None: """Get the id of the owner from the database.""" - # TODO: Follow https://github.com/DIRACGrid/diracx/issues/49 stmt = select(SBOwners.OwnerID).where( SBOwners.Owner == user.preferred_username, SBOwners.OwnerGroup == user.dirac_group, SBOwners.VO == user.vo, ) - result = await self.conn.execute(stmt) - if owner_id := result.scalar_one_or_none(): - return owner_id + return (await self.conn.execute(stmt)).scalar_one_or_none() + async def insert_owner(self, user: UserInfo) -> int: stmt = insert(SBOwners).values( Owner=user.preferred_username, OwnerGroup=user.dirac_group, @@ -50,11 +53,9 @@ def get_pfn(bucket_name: str, user: UserInfo, sandbox_info: 
SandboxInfo) -> str: return "/" + "/".join(parts) async def insert_sandbox( - self, se_name: str, user: UserInfo, pfn: str, size: int + self, owner_id: int, se_name: str, pfn: str, size: int ) -> None: """Add a new sandbox in SandboxMetadataDB.""" - # TODO: Follow https://github.com/DIRACGrid/diracx/issues/49 - owner_id = await self.upsert_owner(user) stmt = insert(SandBoxes).values( OwnerId=owner_id, SEName=se_name, @@ -64,11 +65,9 @@ async def insert_sandbox( LastAccessTime=utcnow(), ) try: - result = await self.conn.execute(stmt) - except IntegrityError: - await self.update_sandbox_last_access_time(se_name, pfn) - else: - assert result.rowcount == 1 + await self.conn.execute(stmt) + except IntegrityError as e: + raise SandboxAlreadyInsertedError(pfn, se_name) from e async def update_sandbox_last_access_time(self, se_name: str, pfn: str) -> None: stmt = ( diff --git a/diracx-db/src/diracx/db/sql/task_queue/db.py b/diracx-db/src/diracx/db/sql/task_queue/db.py index ff701509..cf4be438 100644 --- a/diracx-db/src/diracx/db/sql/task_queue/db.py +++ b/diracx-db/src/diracx/db/sql/task_queue/db.py @@ -7,8 +7,6 @@ if TYPE_CHECKING: pass -from diracx.core.properties import JOB_SHARING, SecurityProperty - from ..utils import BaseSQLDB from .schema import ( BannedSitesQueue, @@ -49,126 +47,23 @@ async def get_owner_for_task_queue(self, tq_id: int) -> dict[str, str]: ) return dict((await self.conn.execute(stmt)).one()._mapping) - async def remove_job(self, job_id: int): - """Remove a job from the task queues.""" - stmt = delete(JobsQueue).where(JobsQueue.JobId == job_id) - await self.conn.execute(stmt) - - async def remove_jobs(self, job_ids: list[int]): - """Remove jobs from the task queues.""" - stmt = delete(JobsQueue).where(JobsQueue.JobId.in_(job_ids)) - await self.conn.execute(stmt) - - async def delete_task_queue_if_empty( - self, - tq_id: int, - tq_owner: str, - tq_group: str, - job_share: int, - group_properties: set[SecurityProperty], - enable_shares_correction: bool, - allow_background_tqs: bool, - ): - """Try to delete a task queue if it's empty.""" - # Check if the task queue is empty - stmt = ( - select(TaskQueues.TQId) - .where(TaskQueues.Enabled >= 1) - .where(TaskQueues.TQId == tq_id) - .where(~TaskQueues.TQId.in_(select(JobsQueue.TQId))) - ) - rows = await self.conn.execute(stmt) - if not rows.rowcount: - return - - # Deleting the task queue (the other tables will be deleted in cascade) - stmt = delete(TaskQueues).where(TaskQueues.TQId == tq_id) - await self.conn.execute(stmt) - - await self.recalculate_tq_shares_for_entity( - tq_owner, - tq_group, - job_share, - group_properties, - enable_shares_correction, - allow_background_tqs, - ) - - async def recalculate_tq_shares_for_entity( - self, - owner: str, - group: str, - job_share: int, - group_properties: set[SecurityProperty], - enable_shares_correction: bool, - allow_background_tqs: bool, - ): - """Recalculate the shares for a user/userGroup combo.""" - if JOB_SHARING in group_properties: - # If group has JobSharing just set prio for that entry, user is irrelevant - return await self.__set_priorities_for_entity( - owner, group, job_share, group_properties, allow_background_tqs - ) - + async def get_task_queue_owners_by_group(self, group: str) -> dict[str, int]: + """Get the owners for a task queue and group.""" stmt = ( select(TaskQueues.Owner, func.count(TaskQueues.Owner)) .where(TaskQueues.OwnerGroup == group) .group_by(TaskQueues.Owner) ) rows = await self.conn.execute(stmt) - # make the rows a list of tuples # Get owners in 
this group and the amount of times they appear # TODO: I guess the rows are already a list of tupes # maybe refactor - data = [(r[0], r[1]) for r in rows if r] - num_owners = len(data) - # If there are no owners do now - if num_owners == 0: - return - # Split the share amongst the number of owners - entities_shares = {row[0]: job_share / num_owners for row in data} - - # TODO: implement the following - # If corrector is enabled let it work it's magic - # if enable_shares_correction: - # entities_shares = await self.__shares_corrector.correct_shares( - # entitiesShares, group=group - # ) - - # Keep updating - owners = dict(data) - # IF the user is already known and has more than 1 tq, the rest of the users don't need to be modified - # (The number of owners didn't change) - if owner in owners and owners[owner] > 1: - await self.__set_priorities_for_entity( - owner, - group, - entities_shares[owner], - group_properties, - allow_background_tqs, - ) - return - # Oops the number of owners may have changed so we recalculate the prio for all owners in the group - for owner in owners: - await self.__set_priorities_for_entity( - owner, - group, - entities_shares[owner], - group_properties, - allow_background_tqs, - ) - - async def __set_priorities_for_entity( - self, - owner: str, - group: str, - share, - properties: set[SecurityProperty], - allow_background_tqs: bool, - ): - """Set the priority for a user/userGroup combo given a splitted share.""" - from DIRAC.WorkloadManagementSystem.DB.TaskQueueDB import calculate_priority + return {r[0]: r[1] for r in rows if r} + async def get_task_queue_priorities( + self, group: str, owner: str | None = None + ) -> dict[int, float]: + """Get the priorities for a list of task queues.""" stmt = ( select( TaskQueues.TQId, @@ -178,24 +73,48 @@ async def __set_priorities_for_entity( .where(TaskQueues.OwnerGroup == group) .group_by(TaskQueues.TQId) ) - if JOB_SHARING not in properties: + if owner: stmt = stmt.where(TaskQueues.Owner == owner) rows = await self.conn.execute(stmt) - tq_dict: dict[int, float] = {tq_id: priority for tq_id, priority in rows} + return {tq_id: priority for tq_id, priority in rows} - if not tq_dict: - return + async def remove_jobs(self, job_ids: list[int]): + """Remove jobs from the task queues.""" + stmt = delete(JobsQueue).where(JobsQueue.JobId.in_(job_ids)) + await self.conn.execute(stmt) - rows = await self.retrieve_task_queues(list(tq_dict)) + async def is_task_queue_empty(self, tq_id: int) -> bool: + """Check if a task queue is empty.""" + stmt = ( + select(TaskQueues.TQId) + .where(TaskQueues.Enabled >= 1) + .where(TaskQueues.TQId == tq_id) + .where(~TaskQueues.TQId.in_(select(JobsQueue.TQId))) + ) + rows = await self.conn.execute(stmt) + return not rows.rowcount - prio_dict = calculate_priority(tq_dict, rows, share, allow_background_tqs) + async def delete_task_queue( + self, + tq_id: int, + ): + """Delete a task queue.""" + # Deleting the task queue (the other tables will be deleted in cascade) + stmt = delete(TaskQueues).where(TaskQueues.TQId == tq_id) + await self.conn.execute(stmt) - # Execute updates - for prio, tqs in prio_dict.items(): - update_stmt = ( - update(TaskQueues).where(TaskQueues.TQId.in_(tqs)).values(Priority=prio) - ) - await self.conn.execute(update_stmt) + async def set_priorities_for_entity( + self, + tq_ids: list[int], + priority: float, + ): + """Set the priority for a user/userGroup combo given a splitted share.""" + update_stmt = ( + update(TaskQueues) + .where(TaskQueues.TQId.in_(tq_ids)) + 
.values(Priority=priority) + ) + await self.conn.execute(update_stmt) async def retrieve_task_queues(self, tq_id_list=None): """Get all the task queues.""" diff --git a/diracx-db/src/diracx/db/sql/utils/__init__.py b/diracx-db/src/diracx/db/sql/utils/__init__.py index cd82d3c7..69b78b4b 100644 --- a/diracx-db/src/diracx/db/sql/utils/__init__.py +++ b/diracx-db/src/diracx/db/sql/utils/__init__.py @@ -6,7 +6,7 @@ apply_search_filters, apply_sort_constraints, ) -from .functions import substract_date, utcnow +from .functions import hash, substract_date, utcnow from .types import Column, DateNowColumn, EnumBackedBool, EnumColumn, NullColumn __all__ = ( @@ -20,5 +20,6 @@ "apply_search_filters", "apply_sort_constraints", "substract_date", + "hash", "SQLDBUnavailableError", ) diff --git a/diracx-db/src/diracx/db/sql/utils/base.py b/diracx-db/src/diracx/db/sql/utils/base.py index dfe6baa8..b02b8ade 100644 --- a/diracx-db/src/diracx/db/sql/utils/base.py +++ b/diracx-db/src/diracx/db/sql/utils/base.py @@ -104,7 +104,7 @@ def available_implementations(cls, db_name: str) -> list[type["BaseSQLDB"]]: db_classes: list[type[BaseSQLDB]] = [ entry_point.load() for entry_point in select_from_extension( - group="diracx.db.sql", name=db_name + group="diracx.dbs.sql", name=db_name ) ] if not db_classes: @@ -119,7 +119,7 @@ def available_urls(cls) -> dict[str, str]: prefixed with ``DIRACX_DB_URL_{DB_NAME}``. """ db_urls: dict[str, str] = {} - for entry_point in select_from_extension(group="diracx.db.sql"): + for entry_point in select_from_extension(group="diracx.dbs.sql"): db_name = entry_point.name var_name = f"DIRACX_DB_URL_{entry_point.name.upper()}" if var_name in os.environ: diff --git a/diracx-db/src/diracx/db/sql/utils/functions.py b/diracx-db/src/diracx/db/sql/utils/functions.py index b327c1ca..eaa4ae75 100644 --- a/diracx-db/src/diracx/db/sql/utils/functions.py +++ b/diracx-db/src/diracx/db/sql/utils/functions.py @@ -1,5 +1,6 @@ from __future__ import annotations +import hashlib from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING @@ -104,3 +105,7 @@ def sqlite_date_trunc(element, compiler, **kw): def substract_date(**kwargs: float) -> datetime: return datetime.now(tz=timezone.utc) - timedelta(**kwargs) + + +def hash(code: str): + return hashlib.sha256(code.encode()).hexdigest() diff --git a/diracx-db/tests/auth/test_authorization_flow.py b/diracx-db/tests/auth/test_authorization_flow.py index 240cd55e..56a3d332 100644 --- a/diracx-db/tests/auth/test_authorization_flow.py +++ b/diracx-db/tests/auth/test_authorization_flow.py @@ -58,11 +58,6 @@ async def test_insert_id_token(auth_db: AuthDB): uuid, id_token, MAX_VALIDITY ) - # We shouldn't be able to retrieve it twice - async with auth_db as auth_db: - with pytest.raises(AuthorizationError, match="already used"): - res = await auth_db.get_authorization_flow(code, MAX_VALIDITY) - async def test_insert(auth_db: AuthDB): # First insert diff --git a/diracx-db/tests/auth/test_device_flow.py b/diracx-db/tests/auth/test_device_flow.py index 45093d2e..112e7789 100644 --- a/diracx-db/tests/auth/test_device_flow.py +++ b/diracx-db/tests/auth/test_device_flow.py @@ -1,13 +1,15 @@ from __future__ import annotations import secrets +from datetime import timezone import pytest from sqlalchemy.exc import NoResultFound -from diracx.core.exceptions import AuthorizationError, ExpiredFlowError +from diracx.core.exceptions import AuthorizationError from diracx.db.sql.auth.db import AuthDB from diracx.db.sql.auth.schema import USER_CODE_LENGTH 
+from diracx.db.sql.utils.functions import substract_date MAX_VALIDITY = 2 EXPIRED = 0 @@ -56,7 +58,7 @@ async def test_device_flow_lookup(auth_db: AuthDB, monkeypatch): async with auth_db as auth_db: with pytest.raises(NoResultFound): - await auth_db.get_device_flow("NotInserted", MAX_VALIDITY) + await auth_db.get_device_flow("NotInserted") # First insert async with auth_db as auth_db: @@ -76,17 +78,8 @@ async def test_device_flow_lookup(auth_db: AuthDB, monkeypatch): await auth_db.device_flow_validate_user_code(user_code1, EXPIRED) await auth_db.device_flow_validate_user_code(user_code1, MAX_VALIDITY) - - # Cannot get it with device_code because no id_token - with pytest.raises(AuthorizationError): - await auth_db.get_device_flow(device_code1, MAX_VALIDITY) - await auth_db.device_flow_validate_user_code(user_code2, MAX_VALIDITY) - # Cannot get it with device_code because no id_token - with pytest.raises(AuthorizationError): - await auth_db.get_device_flow(device_code2, MAX_VALIDITY) - async with auth_db as auth_db: with pytest.raises(AuthorizationError): await auth_db.device_flow_insert_id_token( @@ -103,18 +96,16 @@ async def test_device_flow_lookup(auth_db: AuthDB, monkeypatch): user_code1, {"token": "mytoken2"}, MAX_VALIDITY ) - with pytest.raises(ExpiredFlowError): - await auth_db.get_device_flow(device_code1, EXPIRED) + res = await auth_db.get_device_flow(device_code1) + # The device code should be expired + assert res["CreationTime"].replace(tzinfo=timezone.utc) > substract_date( + seconds=MAX_VALIDITY + ) - res = await auth_db.get_device_flow(device_code1, MAX_VALIDITY) + res = await auth_db.get_device_flow(device_code1) assert res["UserCode"] == user_code1 assert res["IDToken"] == {"token": "mytoken"} - # cannot get it a second time - async with auth_db as auth_db: - with pytest.raises(AuthorizationError): - await auth_db.get_device_flow(device_code1, MAX_VALIDITY) - # Re-adding a token should not work after it's been minted async with auth_db as auth_db: with pytest.raises(AuthorizationError): @@ -146,5 +137,5 @@ async def test_device_flow_insert_id_token(auth_db: AuthDB): await auth_db.device_flow_validate_user_code(user_code, MAX_VALIDITY) async with auth_db as auth_db: - res = await auth_db.get_device_flow(device_code, MAX_VALIDITY) + res = await auth_db.get_device_flow(device_code) assert res["IDToken"] == id_token diff --git a/diracx-db/tests/auth/test_refresh_token.py b/diracx-db/tests/auth/test_refresh_token.py index 2d72cef0..4d0ad404 100644 --- a/diracx-db/tests/auth/test_refresh_token.py +++ b/diracx-db/tests/auth/test_refresh_token.py @@ -1,5 +1,7 @@ from __future__ import annotations +from uuid import UUID, uuid4 + import pytest from diracx.db.sql.auth.db import AuthDB @@ -18,16 +20,20 @@ async def auth_db(tmp_path): async def test_insert(auth_db: AuthDB): """Insert two refresh tokens in the DB and check that they don't share the same JWT ID.""" # Insert a first refresh token + jti1 = uuid4() async with auth_db as auth_db: - jti1, _ = await auth_db.insert_refresh_token( + await auth_db.insert_refresh_token( + jti1, "subject", "username", "vo:lhcb property:NormalUser", ) # Insert a second refresh token + jti2 = uuid4() async with auth_db as auth_db: - jti2, _ = await auth_db.insert_refresh_token( + await auth_db.insert_refresh_token( + jti2, "subject", "username", "vo:lhcb property:NormalUser", @@ -47,12 +53,15 @@ async def test_get(auth_db: AuthDB): } # Insert refresh token details + jti = uuid4() async with auth_db as auth_db: - jti, creation_time = await 
auth_db.insert_refresh_token( + await auth_db.insert_refresh_token( + jti, refresh_token_details["sub"], refresh_token_details["preferred_username"], refresh_token_details["scope"], ) + creation_time = (await auth_db.get_refresh_token(jti))["CreationTime"] # Enrich the dict with the generated refresh token attributes expected_refresh_token = { @@ -69,6 +78,7 @@ async def test_get(auth_db: AuthDB): result = await auth_db.get_refresh_token(jti) # Make sure they are identical + result["JTI"] = UUID(result["JTI"], version=4) assert result == expected_refresh_token @@ -87,6 +97,7 @@ async def test_get_user_refresh_tokens(auth_db: AuthDB): async with auth_db as auth_db: for sub in subjects: await auth_db.insert_refresh_token( + uuid4(), sub, "username", "scope", @@ -112,7 +123,9 @@ async def test_revoke(auth_db: AuthDB): """Insert a refresh token in the DB, revoke it, and make sure it appears as REVOKED in the db.""" # Insert a refresh token details async with auth_db as auth_db: - jti, _ = await auth_db.insert_refresh_token( + jti = uuid4() + await auth_db.insert_refresh_token( + jti, "subject", "username", "scope", @@ -142,6 +155,7 @@ async def test_revoke_user_refresh_tokens(auth_db: AuthDB): async with auth_db as auth_db: for sub in subjects: await auth_db.insert_refresh_token( + uuid4(), sub, "username", "scope", @@ -184,7 +198,9 @@ async def test_revoke_and_get_user_refresh_tokens(auth_db: AuthDB): jtis = [] async with auth_db as auth_db: for _ in range(nb_tokens): - jti, _ = await auth_db.insert_refresh_token( + jti = uuid4() + await auth_db.insert_refresh_token( + jti, sub, "username", "scope", @@ -232,6 +248,7 @@ async def test_get_refresh_tokens(auth_db: AuthDB): async with auth_db as auth_db: for sub in subjects: await auth_db.insert_refresh_token( + uuid4(), sub, "username", "scope", diff --git a/diracx-db/tests/jobs/test_job_db.py b/diracx-db/tests/jobs/test_job_db.py index 060bd7d8..e6ca58ce 100644 --- a/diracx-db/tests/jobs/test_job_db.py +++ b/diracx-db/tests/jobs/test_job_db.py @@ -1,8 +1,9 @@ from __future__ import annotations import pytest +from sqlalchemy.exc import IntegrityError -from diracx.core.exceptions import InvalidQueryError, JobNotFoundError +from diracx.core.exceptions import InvalidQueryError from diracx.core.models import ( ScalarSearchOperator, ScalarSearchSpec, @@ -12,7 +13,6 @@ VectorSearchSpec, ) from diracx.db.sql.job.db import JobDB -from diracx.db.sql.utils.job import JobSubmissionSpec, submit_jobs_jdl @pytest.fixture @@ -27,29 +27,29 @@ async def job_db(tmp_path): yield job_db -async def test_search_parameters(job_db): +@pytest.fixture +async def populated_job_db(job_db): + """Populate the in-memory JobDB with 100 jobs using DAL calls.""" + async with job_db as db: + jobs_to_insert = {} + # Insert 100 jobs directly via the DAL. 
+ for i in range(100): + compressed_jdl = f"CompressedJDL{i}" + job_id = await db.create_job(compressed_jdl) + jobs_to_insert[job_id] = { + "JobID": job_id, + "Status": "New", + "Owner": f"owner{i}", + "OwnerGroup": "owner_group1" if i < 50 else "owner_group2", + "VO": "lhcb", + } + await db.insert_job_attributes(jobs_to_insert) + yield job_db + + +async def test_search_parameters(populated_job_db): """Test that we can search specific parameters for jobs in the database.""" - async with job_db as job_db: - total, result = await job_db.search(["JobID"], [], []) - assert total == 0 - assert not result - - result = await submit_jobs_jdl( - [ - JobSubmissionSpec( - jdl=f"JDL{i}", - owner="owner", - owner_group="owner_group", - initial_status="New", - initial_minor_status="dfdfds", - vo="lhcb", - ) - for i in range(100) - ], - job_db, - ) - - async with job_db as job_db: + async with populated_job_db as job_db: # Search a specific parameter: JobID total, result = await job_db.search(["JobID"], [], []) assert total == 100 @@ -81,25 +81,9 @@ async def test_search_parameters(job_db): total, result = await job_db.search(["Dummy"], [], []) -async def test_search_conditions(job_db): +async def test_search_conditions(populated_job_db): """Test that we can search for specific jobs in the database.""" - async with job_db as job_db: - result = await submit_jobs_jdl( - [ - JobSubmissionSpec( - jdl=f"JDL{i}", - owner=f"owner{i}", - owner_group="owner_group", - initial_status="New", - initial_minor_status="dfdfds", - vo="lhcb", - ) - for i in range(100) - ], - job_db, - ) - - async with job_db as job_db: + async with populated_job_db as job_db: # Search a specific scalar condition: JobID eq 3 condition = ScalarSearchSpec( parameter="JobID", operator=ScalarSearchOperator.EQUAL, value=3 @@ -204,25 +188,9 @@ async def test_search_conditions(job_db): assert not result -async def test_search_sorts(job_db): +async def test_search_sorts(populated_job_db): """Test that we can search for jobs in the database and sort the results.""" - async with job_db as job_db: - result = await submit_jobs_jdl( - [ - JobSubmissionSpec( - jdl=f"JDL{i}", - owner=f"owner{i}", - owner_group="owner_group1" if i < 50 else "owner_group2", - initial_status="New", - initial_minor_status="dfdfds", - vo="lhcb", - ) - for i in range(100) - ], - job_db, - ) - - async with job_db as job_db: + async with populated_job_db as job_db: # Search and sort by JobID in ascending order sort = SortSpec(parameter="JobID", direction=SortDirection.ASC) total, result = await job_db.search([], [], [sort]) @@ -269,25 +237,9 @@ async def test_search_sorts(job_db): assert result[99]["JobID"] == 51 -async def test_search_pagination(job_db): +async def test_search_pagination(populated_job_db): """Test that we can search for jobs in the database.""" - async with job_db as job_db: - result = await submit_jobs_jdl( - [ - JobSubmissionSpec( - jdl=f"JDL{i}", - owner="owner", - owner_group="owner_group", - initial_status="New", - initial_minor_status="dfdfds", - vo="lhcb", - ) - for i in range(100) - ], - job_db, - ) - - async with job_db as job_db: + async with populated_job_db as job_db: # Search for the first 10 jobs total, result = await job_db.search([], [], [], per_page=10, page=1) assert total == 100 @@ -330,8 +282,8 @@ async def test_search_pagination(job_db): result = await job_db.search([], [], [], per_page=0, page=1) -async def test_set_job_command_invalid_job_id(job_db: JobDB): +async def test_set_job_commands_invalid_job_id(job_db: JobDB): """Test that 
setting a command for a non-existent job raises JobNotFound.""" async with job_db as job_db: - with pytest.raises(JobNotFoundError): - await job_db.set_job_command(123456, "test_command") + with pytest.raises(IntegrityError): + await job_db.set_job_commands([(123456, "test_command", "")]) diff --git a/diracx-db/tests/jobs/test_job_logging_db.py b/diracx-db/tests/jobs/test_job_logging_db.py index 0e2f815f..829e826d 100644 --- a/diracx-db/tests/jobs/test_job_logging_db.py +++ b/diracx-db/tests/jobs/test_job_logging_db.py @@ -4,7 +4,7 @@ import pytest -from diracx.core.models import JobStatus +from diracx.core.models import JobLoggingRecord, JobStatus from diracx.db.sql import JobLoggingDB @@ -23,31 +23,39 @@ async def test_insert_records(job_logging_db: JobLoggingDB): date = datetime.now(timezone.utc) # Act + records = [] for i in range(50): - await job_logging_db.insert_record( - i, - status=JobStatus.RECEIVED, - minor_status="received_minor_status", - application_status="application_status", - date=date, - source="pytest", + records.append( + JobLoggingRecord( + job_id=i, + status=JobStatus.RECEIVED, + minor_status="received_minor_status", + application_status="application_status", + date=date, + source="pytest", + ) ) - await job_logging_db.insert_record( - i, - status=JobStatus.SUBMITTING, - minor_status="submitted_minor_status", - application_status="application_status", - date=date, - source="pytest", + records.append( + JobLoggingRecord( + job_id=i, + status=JobStatus.SUBMITTING, + minor_status="submitted_minor_status", + application_status="application_status", + date=date, + source="pytest", + ) ) - await job_logging_db.insert_record( - i, - status=JobStatus.RUNNING, - minor_status="running_minor_status", - application_status="application_status", - date=date, - source="pytest", + records.append( + JobLoggingRecord( + job_id=i, + status=JobStatus.RUNNING, + minor_status="running_minor_status", + application_status="application_status", + date=date, + source="pytest", + ) ) + await job_logging_db.insert_records(records) # Assert res = await job_logging_db.get_records([i for i in range(50)]) diff --git a/diracx-db/tests/jobs/test_sandbox_metadata.py b/diracx-db/tests/jobs/test_sandbox_metadata.py index d7b6bcbd..45ff3d9a 100644 --- a/diracx-db/tests/jobs/test_sandbox_metadata.py +++ b/diracx-db/tests/jobs/test_sandbox_metadata.py @@ -7,7 +7,7 @@ import pytest import sqlalchemy -from diracx.core.exceptions import SandboxNotFoundError +from diracx.core.exceptions import SandboxAlreadyInsertedError, SandboxNotFoundError from diracx.core.models import SandboxInfo, UserInfo from diracx.db.sql.sandbox_metadata.db import SandboxMetadataDB from diracx.db.sql.sandbox_metadata.schema import SandBoxes, SBEntityMapping @@ -40,6 +40,7 @@ def test_get_pfn(sandbox_metadata_db: SandboxMetadataDB): async def test_insert_sandbox(sandbox_metadata_db: SandboxMetadataDB): + # TODO: DAL tests should be very simple, such complex tests should be handled in diracx-routers user_info = UserInfo( sub="vo:sub", preferred_username="user1", dirac_group="group1", vo="vo" ) @@ -52,16 +53,24 @@ async def test_insert_sandbox(sandbox_metadata_db: SandboxMetadataDB): with pytest.raises(SandboxNotFoundError): await sandbox_metadata_db.sandbox_is_assigned(pfn1, "SandboxSE") + # Insert owner + async with sandbox_metadata_db: + owner_id = await sandbox_metadata_db.insert_owner(user_info) + assert owner_id == 1 + # Insert the sandbox async with sandbox_metadata_db: - await sandbox_metadata_db.insert_sandbox("SandboxSE", 
user_info, pfn1, 100) + await sandbox_metadata_db.insert_sandbox(owner_id, "SandboxSE", pfn1, 100) db_contents = await _dump_db(sandbox_metadata_db) owner_id1, last_access_time1 = db_contents[pfn1] - # Inserting again should update the last access time await asyncio.sleep(1) # The timestamp only has second precision async with sandbox_metadata_db: - await sandbox_metadata_db.insert_sandbox("SandboxSE", user_info, pfn1, 100) + with pytest.raises(SandboxAlreadyInsertedError): + await sandbox_metadata_db.insert_sandbox(owner_id, "SandboxSE", pfn1, 100) + + await sandbox_metadata_db.update_sandbox_last_access_time("SandboxSE", pfn1) + db_contents = await _dump_db(sandbox_metadata_db) owner_id2, last_access_time2 = db_contents[pfn1] assert owner_id1 == owner_id2 @@ -99,6 +108,7 @@ async def _dump_db( async def test_assign_and_unsassign_sandbox_to_jobs( sandbox_metadata_db: SandboxMetadataDB, ): + # TODO: DAL tests should be very simple, such complex tests should be handled in diracx-routers pfn = secrets.token_hex() user_info = UserInfo( sub="vo:sub", preferred_username="user1", dirac_group="group1", vo="vo" @@ -107,7 +117,8 @@ async def test_assign_and_unsassign_sandbox_to_jobs( sandbox_se = "SandboxSE" # Insert the sandbox async with sandbox_metadata_db: - await sandbox_metadata_db.insert_sandbox(sandbox_se, user_info, pfn, 100) + owner_id = await sandbox_metadata_db.insert_owner(user_info) + await sandbox_metadata_db.insert_sandbox(owner_id, sandbox_se, pfn, 100) async with sandbox_metadata_db: stmt = sqlalchemy.select(SandBoxes.SBId, SandBoxes.SEPFN) diff --git a/diracx-logic/README.md b/diracx-logic/README.md new file mode 100644 index 00000000..e69de29b diff --git a/diracx-logic/pyproject.toml b/diracx-logic/pyproject.toml new file mode 100644 index 00000000..4485fedf --- /dev/null +++ b/diracx-logic/pyproject.toml @@ -0,0 +1,40 @@ +[project] +name = "diracx-logic" +description = "TODO" +readme = "README.md" +requires-python = ">=3.11" +keywords = [] +license = {text = "GPL-3.0-only"} +classifiers = [ + "Intended Audience :: Science/Research", + "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", + "Programming Language :: Python :: 3", + "Topic :: Scientific/Engineering", + "Topic :: System :: Distributed Computing", +] +dependencies = [ + "cachetools", + "dirac", + "diracx-core", + "pydantic >=2.10", +] +dynamic = ["version"] + +[project.optional-dependencies] +types = [ + "types-cachetools", +] + +[tool.setuptools.packages.find] +where = ["src"] + +[build-system] +requires = ["setuptools>=61", "wheel", "setuptools_scm>=8"] +build-backend = "setuptools.build_meta" + +[tool.setuptools_scm] +root = ".." 
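+# setuptools_scm derives the version from the git metadata at the repository
+# root, one directory above this sub-package (hence root = "..").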
+ + +[tool.pytest.ini_options] +asyncio_mode = "auto" diff --git a/diracx-logic/src/diracx/logic/__init__.py b/diracx-logic/src/diracx/logic/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/diracx-logic/src/diracx/logic/auth/__init__.py b/diracx-logic/src/diracx/logic/auth/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/diracx-logic/src/diracx/logic/auth/authorize_code_flow.py b/diracx-logic/src/diracx/logic/auth/authorize_code_flow.py new file mode 100644 index 00000000..5c43fdfd --- /dev/null +++ b/diracx-logic/src/diracx/logic/auth/authorize_code_flow.py @@ -0,0 +1,99 @@ +"""Authorization code flow.""" + +from __future__ import annotations + +from typing import Literal + +from diracx.core.config import Config +from diracx.core.models import GrantType +from diracx.core.properties import SecurityProperty +from diracx.core.settings import AuthSettings +from diracx.db.sql import AuthDB + +from .utils import ( + decrypt_state, + get_token_from_iam, + initiate_authorization_flow_with_iam, + parse_and_validate_scope, +) + + +async def initiate_authorization_flow( + request_url: str, + code_challenge: str, + code_challenge_method: Literal["S256"], + client_id: str, + redirect_uri: str, + scope: str, + state: str, + auth_db: AuthDB, + config: Config, + settings: AuthSettings, + available_properties: set[SecurityProperty], +) -> str: + """Initiate the authorization flow.""" + if settings.dirac_client_id != client_id: + raise ValueError("Unrecognised client_id") + if redirect_uri not in settings.allowed_redirects: + raise ValueError("Unrecognised redirect_uri") + + # Parse and validate the scope + parsed_scope = parse_and_validate_scope(scope, config, available_properties) + + # Store the authorization flow details + uuid = await auth_db.insert_authorization_flow( + client_id, + scope, + code_challenge, + code_challenge_method, + redirect_uri, + ) + + # Initiate the authorization flow with the IAM + state_for_iam = { + "external_state": state, + "uuid": uuid, + "grant_type": GrantType.authorization_code.value, + } + + authorization_flow_url = await initiate_authorization_flow_with_iam( + config, + parsed_scope["vo"], + f"{request_url}/complete", + state_for_iam, + settings.state_key.fernet, + ) + + return authorization_flow_url + + +async def complete_authorization_flow( + code: str, + state: str, + request_url: str, + auth_db: AuthDB, + config: Config, + settings: AuthSettings, +) -> str: + """Complete the authorization flow.""" + # Decrypt the state to access user details + decrypted_state = decrypt_state(state, settings.state_key.fernet) + assert decrypted_state["grant_type"] == GrantType.authorization_code + + # Get the ID token from the IAM + id_token = await get_token_from_iam( + config, + decrypted_state["vo"], + code, + decrypted_state, + request_url, + ) + + # Store the ID token and redirect the user to the client's redirect URI + code, redirect_uri = await auth_db.authorization_flow_insert_id_token( + decrypted_state["uuid"], + id_token, + settings.authorization_flow_expiration_seconds, + ) + + return f"{redirect_uri}?code={code}&state={decrypted_state['external_state']}" diff --git a/diracx-logic/src/diracx/logic/auth/device_flow.py b/diracx-logic/src/diracx/logic/auth/device_flow.py new file mode 100644 index 00000000..1e35ad3e --- /dev/null +++ b/diracx-logic/src/diracx/logic/auth/device_flow.py @@ -0,0 +1,100 @@ +"""Device flow.""" + +from __future__ import annotations + +from diracx.core.config import Config +from diracx.core.models 
import GrantType, InitiateDeviceFlowResponse +from diracx.core.properties import SecurityProperty +from diracx.core.settings import AuthSettings +from diracx.db.sql import AuthDB + +from .utils import ( + decrypt_state, + get_token_from_iam, + initiate_authorization_flow_with_iam, + parse_and_validate_scope, +) + + +async def initiate_device_flow( + client_id: str, + scope: str, + verification_uri: str, + auth_db: AuthDB, + config: Config, + available_properties: set[SecurityProperty], + settings: AuthSettings, +) -> InitiateDeviceFlowResponse: + """Initiate the device flow against DIRAC authorization Server.""" + if settings.dirac_client_id != client_id: + raise ValueError("Unrecognised client ID") + + parse_and_validate_scope(scope, config, available_properties) + + user_code, device_code = await auth_db.insert_device_flow(client_id, scope) + + return { + "user_code": user_code, + "device_code": device_code, + "verification_uri_complete": f"{verification_uri}?user_code={user_code}", + "verification_uri": verification_uri, + "expires_in": settings.device_flow_expiration_seconds, + } + + +async def do_device_flow( + request_url: str, + auth_db: AuthDB, + user_code: str, + config: Config, + available_properties: set[SecurityProperty], + settings: AuthSettings, +) -> str: + """This is called as the verification URI for the device flow.""" + # Here we make sure the user_code actually exists + scope = await auth_db.device_flow_validate_user_code( + user_code, settings.device_flow_expiration_seconds + ) + parsed_scope = parse_and_validate_scope(scope, config, available_properties) + + redirect_uri = f"{request_url}/complete" + + state_for_iam = { + "grant_type": GrantType.device_code.value, + "user_code": user_code, + } + + authorization_flow_url = await initiate_authorization_flow_with_iam( + config, + parsed_scope["vo"], + redirect_uri, + state_for_iam, + settings.state_key.fernet, + ) + return authorization_flow_url + + +async def finish_device_flow( + request_url: str, + code: str, + state: str, + auth_db: AuthDB, + config: Config, + settings: AuthSettings, +): + """This the url callbacked by IAM/Checkin after the authorization + flow was granted. + """ + decrypted_state = decrypt_state(state, settings.state_key.fernet) + assert decrypted_state["grant_type"] == GrantType.device_code + + id_token = await get_token_from_iam( + config, + decrypted_state["vo"], + code, + decrypted_state, + request_url, + ) + await auth_db.device_flow_insert_id_token( + decrypted_state["user_code"], id_token, settings.device_flow_expiration_seconds + ) diff --git a/diracx-logic/src/diracx/logic/auth/management.py b/diracx-logic/src/diracx/logic/auth/management.py new file mode 100644 index 00000000..2f2207d0 --- /dev/null +++ b/diracx-logic/src/diracx/logic/auth/management.py @@ -0,0 +1,32 @@ +"""This module contains the auth management functions.""" + +from __future__ import annotations + +from uuid import UUID + +from diracx.db.sql import AuthDB + + +async def get_refresh_tokens( + auth_db: AuthDB, + subject: str | None, +) -> list: + """Get all refresh tokens bound to a given subject. If there is no subject, then + all the refresh tokens are retrieved. + """ + return await auth_db.get_user_refresh_tokens(subject) + + +async def revoke_refresh_token( + auth_db: AuthDB, + subject: str | None, + jti: UUID, +) -> str: + """Revoke a refresh token. 
If a subject is provided, then the refresh token must be owned by that subject.""" + res = await auth_db.get_refresh_token(jti) + + if subject and subject != res["Sub"]: + raise PermissionError("Cannot revoke a refresh token owned by someone else") + + await auth_db.revoke_refresh_token(jti) + return f"Refresh token {jti} revoked" diff --git a/diracx-logic/src/diracx/logic/auth/token.py b/diracx-logic/src/diracx/logic/auth/token.py new file mode 100644 index 00000000..16dc7c0b --- /dev/null +++ b/diracx-logic/src/diracx/logic/auth/token.py @@ -0,0 +1,396 @@ +"""Token endpoint implementation.""" + +from __future__ import annotations + +import base64 +import hashlib +import re +from datetime import datetime, timedelta, timezone +from uuid import UUID, uuid4 + +from authlib.jose import JsonWebToken + +from diracx.core.config import Config +from diracx.core.exceptions import ( + AuthorizationError, + ExpiredFlowError, + InvalidCredentialsError, + PendingAuthorizationError, +) +from diracx.core.models import ( + AccessTokenPayload, + GrantType, + RefreshTokenPayload, + TokenPayload, +) +from diracx.core.properties import SecurityProperty +from diracx.core.settings import AuthSettings +from diracx.db.sql import AuthDB +from diracx.db.sql.auth.schema import FlowStatus, RefreshTokenStatus +from diracx.db.sql.utils.functions import substract_date + +from .utils import ( + get_allowed_user_properties, + parse_and_validate_scope, + verify_dirac_refresh_token, +) + + +async def get_oidc_token( + grant_type: GrantType, + client_id: str, + auth_db: AuthDB, + config: Config, + settings: AuthSettings, + available_properties: set[SecurityProperty], + device_code: str | None = None, + code: str | None = None, + redirect_uri: str | None = None, + code_verifier: str | None = None, + refresh_token: str | None = None, +) -> tuple[AccessTokenPayload, RefreshTokenPayload]: + """Token endpoint to retrieve the token at the end of a flow.""" + legacy_exchange = False + + if grant_type == GrantType.device_code: + assert device_code is not None + oidc_token_info, scope = await get_oidc_token_info_from_device_flow( + device_code, client_id, auth_db, settings + ) + + elif grant_type == GrantType.authorization_code: + assert code is not None + assert code_verifier is not None + oidc_token_info, scope = await get_oidc_token_info_from_authorization_flow( + code, client_id, redirect_uri, code_verifier, auth_db, settings + ) + + elif grant_type == GrantType.refresh_token: + assert refresh_token is not None + ( + oidc_token_info, + scope, + legacy_exchange, + ) = await get_oidc_token_info_from_refresh_flow( + refresh_token, auth_db, settings + ) + else: + raise NotImplementedError(f"Grant type not implemented {grant_type}") + + # Get a TokenResponse to return to the user + return await exchange_token( + auth_db, + scope, + oidc_token_info, + config, + settings, + available_properties, + legacy_exchange=legacy_exchange, + ) + + +async def get_oidc_token_info_from_device_flow( + device_code: str, client_id: str, auth_db: AuthDB, settings: AuthSettings +) -> tuple[dict, str]: + """Get OIDC token information from the device flow DB and check few parameters before returning it.""" + info = await get_device_flow( + auth_db, device_code, settings.device_flow_expiration_seconds + ) + + if info["ClientID"] != client_id: + raise ValueError("Bad client_id") + + oidc_token_info = info["IDToken"] + scope = info["Scope"] + + # TODO: use HTTPException while still respecting the standard format + # required by the RFC + if info["Status"] != 
FlowStatus.READY: + # That should never ever happen + raise NotImplementedError(f"Unexpected flow status {info['status']!r}") + return (oidc_token_info, scope) + + +async def get_oidc_token_info_from_authorization_flow( + code: str, + client_id: str | None, + redirect_uri: str | None, + code_verifier: str, + auth_db: AuthDB, + settings: AuthSettings, +) -> tuple[dict, str]: + """Get OIDC token information from the authorization flow DB and check few parameters before returning it.""" + info = await get_authorization_flow( + auth_db, code, settings.authorization_flow_expiration_seconds + ) + if redirect_uri != info["RedirectURI"]: + raise ValueError("Invalid redirect_uri") + if client_id != info["ClientID"]: + raise ValueError("Bad client_id") + + # Check the code_verifier + try: + code_challenge = ( + base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()) + .decode() + .strip("=") + ) + except Exception as e: + raise ValueError("Malformed code_verifier") from e + + if code_challenge != info["CodeChallenge"]: + raise ValueError("Invalid code_challenge") + + oidc_token_info = info["IDToken"] + scope = info["Scope"] + + # TODO: use HTTPException while still respecting the standard format + # required by the RFC + if info["Status"] != FlowStatus.READY: + # That should never ever happen + raise NotImplementedError(f"Unexpected flow status {info['status']!r}") + + return (oidc_token_info, scope) + + +async def get_oidc_token_info_from_refresh_flow( + refresh_token: str, auth_db: AuthDB, settings: AuthSettings +) -> tuple[dict, str, bool]: + """Get OIDC token information from the refresh token DB and check few parameters before returning it.""" + # Decode the refresh token to get the JWT ID + jti, _, legacy_exchange = await verify_dirac_refresh_token(refresh_token, settings) + + # Get some useful user information from the refresh token entry in the DB + refresh_token_attributes = await auth_db.get_refresh_token(jti) + + sub = refresh_token_attributes["Sub"] + + # Check if the refresh token was obtained from the legacy_exchange endpoint + # If it is the case, we bypass the refresh token rotation mechanism + if not legacy_exchange: + # Refresh token rotation: https://datatracker.ietf.org/doc/html/rfc6749#section-10.4 + # Check that the refresh token has not been already revoked + # This might indicate that a potential attacker try to impersonate someone + # In such case, all the refresh tokens bound to a given user (subject) should be revoked + # Forcing the user to reauthenticate interactively through an authorization/device flow (recommended practice) + if refresh_token_attributes["Status"] == RefreshTokenStatus.REVOKED: + # Revoke all the user tokens from the subject + await auth_db.revoke_user_refresh_tokens(sub) + + # Commit here, otherwise the revokation operation will not be taken into account + # as we return an error to the user + await auth_db.conn.commit() + + raise InvalidCredentialsError( + "Revoked refresh token reused: potential attack detected. 
You must authenticate again" + ) + + # Part of the refresh token rotation mechanism: + # Revoke the refresh token provided, a new one needs to be generated + await auth_db.revoke_refresh_token(jti) + + # Build an ID token and get scope from the refresh token attributes received + oidc_token_info = { + # The sub attribute coming from the DB contains the VO name + # We need to remove it as if it were coming from an ID token from an external IdP + "sub": sub.split(":", 1)[1], + "preferred_username": refresh_token_attributes["PreferredUsername"], + } + scope = refresh_token_attributes["Scope"] + return (oidc_token_info, scope, legacy_exchange) + + +BASE_64_URL_SAFE_PATTERN = ( + r"(?:[A-Za-z0-9\-_]{4})*(?:[A-Za-z0-9\-_]{2}==|[A-Za-z0-9\-_]{3}=)?" +) +LEGACY_EXCHANGE_PATTERN = rf"Bearer diracx:legacy:({BASE_64_URL_SAFE_PATTERN})" + + +async def perform_legacy_exchange( + expected_api_key: str, + preferred_username: str, + scope: str, + authorization: str, + auth_db: AuthDB, + available_properties: set[SecurityProperty], + settings: AuthSettings, + config: Config, + expires_minutes: int | None = None, +) -> tuple[AccessTokenPayload, RefreshTokenPayload]: + """Endpoint used by legacy DIRAC to mint tokens for proxy -> token exchange.""" + if match := re.fullmatch(LEGACY_EXCHANGE_PATTERN, authorization): + raw_token = base64.urlsafe_b64decode(match.group(1)) + else: + raise ValueError("Invalid authorization header") + + if hashlib.sha256(raw_token).hexdigest() != expected_api_key: + raise InvalidCredentialsError("Invalid credentials") + + try: + parsed_scope = parse_and_validate_scope(scope, config, available_properties) + vo_users = config.Registry[parsed_scope["vo"]] + sub = vo_users.sub_from_preferred_username(preferred_username) + except (KeyError, ValueError) as e: + raise ValueError("Invalid scope or preferred_username") from e + + return await exchange_token( + auth_db, + scope, + {"sub": sub, "preferred_username": preferred_username}, + config, + settings, + available_properties, + refresh_token_expire_minutes=expires_minutes, + legacy_exchange=True, + ) + + +async def exchange_token( + auth_db: AuthDB, + scope: str, + oidc_token_info: dict, + config: Config, + settings: AuthSettings, + available_properties: set[SecurityProperty], + *, + refresh_token_expire_minutes: int | None = None, + legacy_exchange: bool = False, +) -> tuple[AccessTokenPayload, RefreshTokenPayload]: + """Method called to exchange the OIDC token for a DIRAC generated access token.""" + # Extract dirac attributes from the OIDC scope + parsed_scope = parse_and_validate_scope(scope, config, available_properties) + vo = parsed_scope["vo"] + dirac_group = parsed_scope["group"] + properties = parsed_scope["properties"] + + # Extract attributes from the OIDC token details + sub = oidc_token_info["sub"] + if user_info := config.Registry[vo].Users.get(sub): + preferred_username = user_info.PreferedUsername + else: + preferred_username = oidc_token_info.get("preferred_username", sub) + raise NotImplementedError( + "Dynamic registration of users is not yet implemented" + ) + + # Check that the subject is part of the dirac users + if sub not in config.Registry[vo].Groups[dirac_group].Users: + raise PermissionError( + f"User is not a member of the requested group ({preferred_username}, {dirac_group})" + ) + + # Check that the user properties are valid + allowed_user_properties = get_allowed_user_properties(config, sub, vo) + if not properties.issubset(allowed_user_properties): + raise PermissionError( + f"{' '.join(properties - 
allowed_user_properties)} are not valid properties " + f"for user {preferred_username}, available values: {' '.join(allowed_user_properties)}" + ) + + # Merge the VO with the subject to get a unique DIRAC sub + sub = f"{vo}:{sub}" + + # Insert the refresh token with user details into the RefreshTokens table + # User details are needed to regenerate access tokens later + jti, creation_time = await insert_refresh_token( + auth_db=auth_db, + subject=sub, + preferred_username=preferred_username, + scope=scope, + ) + + # Generate refresh token payload + if refresh_token_expire_minutes is None: + refresh_token_expire_minutes = settings.refresh_token_expire_minutes + refresh_payload: RefreshTokenPayload = { + "jti": str(jti), + "exp": creation_time + timedelta(minutes=refresh_token_expire_minutes), + # legacy_exchange is used to indicate that the original refresh token + # was obtained from the legacy_exchange endpoint + "legacy_exchange": legacy_exchange, + "dirac_policies": {}, + } + + # Generate access token payload + # For now, the access token is only used to access DIRAC services, + # therefore, the audience is not set and checked + access_payload: AccessTokenPayload = { + "sub": sub, + "vo": vo, + "iss": settings.token_issuer, + "dirac_properties": list(properties), + "jti": str(uuid4()), + "preferred_username": preferred_username, + "dirac_group": dirac_group, + "exp": creation_time + timedelta(minutes=settings.access_token_expire_minutes), + "dirac_policies": {}, + } + + return access_payload, refresh_payload + + +def create_token(payload: TokenPayload, settings: AuthSettings) -> str: + jwt = JsonWebToken(settings.token_algorithm) + encoded_jwt = jwt.encode( + {"alg": settings.token_algorithm}, payload, settings.token_key.jwk + ) + return encoded_jwt.decode("ascii") + + +async def insert_refresh_token( + auth_db: AuthDB, + subject: str, + preferred_username: str, + scope: str, +) -> tuple[UUID, datetime]: + """Insert a refresh token into the database and return the JWT ID and creation time.""" + # Generate a JWT ID + jti = uuid4() + + # Insert the refresh token into the DB + await auth_db.insert_refresh_token( + jti=jti, + subject=subject, + preferred_username=preferred_username, + scope=scope, + ) + + # Get the creation time of the refresh token + refresh_token = await auth_db.get_refresh_token(jti) + return jti, refresh_token["CreationTime"] + + +async def get_device_flow(auth_db: AuthDB, device_code: str, max_validity: int): + """Get the device flow from the DB and check few parameters before returning it.""" + res = await auth_db.get_device_flow(device_code) + + if res["CreationTime"].replace(tzinfo=timezone.utc) < substract_date( + seconds=max_validity + ): + raise ExpiredFlowError() + + if res["Status"] == FlowStatus.READY: + await auth_db.update_device_flow_status(device_code, FlowStatus.DONE) + return res + + if res["Status"] == FlowStatus.DONE: + raise AuthorizationError("Code was already used") + + if res["Status"] == FlowStatus.PENDING: + raise PendingAuthorizationError() + + raise AuthorizationError("Bad state in device flow") + + +async def get_authorization_flow(auth_db: AuthDB, code: str, max_validity: int): + """Get the authorization flow from the DB and check few parameters before returning it.""" + res = await auth_db.get_authorization_flow(code, max_validity) + + if res["Status"] == FlowStatus.READY: + await auth_db.update_authorization_flow_status(code, FlowStatus.DONE) + return res + + if res["Status"] == FlowStatus.DONE: + raise AuthorizationError("Code was already 
used") + + raise AuthorizationError("Bad state in authorization flow") diff --git a/diracx-logic/src/diracx/logic/auth/utils.py b/diracx-logic/src/diracx/logic/auth/utils.py new file mode 100644 index 00000000..20df4101 --- /dev/null +++ b/diracx-logic/src/diracx/logic/auth/utils.py @@ -0,0 +1,288 @@ +from __future__ import annotations + +import base64 +import hashlib +import json +import secrets +from uuid import UUID + +import httpx +from authlib.integrations.starlette_client import OAuthError +from authlib.jose import JsonWebKey, JsonWebToken +from authlib.oidc.core import IDToken +from cachetools import TTLCache +from cryptography.fernet import Fernet +from typing_extensions import TypedDict + +from diracx.core.config.schema import Config +from diracx.core.exceptions import AuthorizationError, IAMClientError, IAMServerError +from diracx.core.models import GrantType +from diracx.core.properties import SecurityProperty +from diracx.core.settings import AuthSettings + + +class ScopeInfoDict(TypedDict): + group: str + properties: set[str] + vo: str + + +_server_metadata_cache: TTLCache = TTLCache(maxsize=1024, ttl=3600) + + +async def get_server_metadata(url: str): + """Get the server metadata from the IAM.""" + server_metadata = _server_metadata_cache.get(url) + if server_metadata is None: + async with httpx.AsyncClient() as c: + res = await c.get(url) + if res.status_code != 200: + # TODO: Better error handling + raise NotImplementedError(res) + server_metadata = res.json() + _server_metadata_cache[url] = server_metadata + return server_metadata + + +def encrypt_state(state_dict: dict[str, str], cipher_suite: Fernet) -> str: + """Encrypt the state dict and return it as a string.""" + return cipher_suite.encrypt( + base64.urlsafe_b64encode(json.dumps(state_dict).encode()) + ).decode() + + +def decrypt_state(state: str, cipher_suite: Fernet) -> dict[str, str]: + """Decrypt the state string and return it as a dict.""" + try: + return json.loads( + base64.urlsafe_b64decode(cipher_suite.decrypt(state.encode())).decode() + ) + except Exception as e: + raise AuthorizationError("Invalid state") from e + + +async def fetch_jwk_set(url: str): + """Fetch the JWK set from the IAM.""" + server_metadata = await get_server_metadata(url) + + jwks_uri = server_metadata.get("jwks_uri") + if not jwks_uri: + raise RuntimeError('Missing "jwks_uri" in metadata') + + async with httpx.AsyncClient() as c: + res = await c.get(jwks_uri) + if res.status_code != 200: + # TODO: Better error handling + raise NotImplementedError(res) + jwk_set = res.json() + + return JsonWebKey.import_key_set(jwk_set) + + +async def parse_id_token(config, vo, raw_id_token: str): + """Parse and validate the ID token from IAM.""" + server_metadata = await get_server_metadata( + config.Registry[vo].IdP.server_metadata_url + ) + alg_values = server_metadata.get("id_token_signing_alg_values_supported", ["RS256"]) + jwk_set = await fetch_jwk_set(config.Registry[vo].IdP.server_metadata_url) + + token = JsonWebToken(alg_values).decode( + raw_id_token, + key=jwk_set, + claims_cls=IDToken, + claims_options={ + "iss": {"values": [server_metadata["issuer"]]}, + # The audience is a required parameter and is the client ID of the application + # https://openid.net/specs/openid-connect-core-1_0.html#IDToken + "aud": {"values": [config.Registry[vo].IdP.ClientID]}, + }, + ) + token.validate() + return token + + +async def initiate_authorization_flow_with_iam( + config, vo: str, redirect_uri: str, state: dict[str, str], cipher_suite: Fernet +): + 
"""Initiate the authorization flow with the IAM. Return the URL to redirect the user to. + + The state dict is encrypted and passed to the IAM. + It is then decrypted when the user is redirected back to the redirect_uri. + """ + # code_verifier: https://www.rfc-editor.org/rfc/rfc7636#section-4.1 + code_verifier = secrets.token_hex() + + # code_challenge: https://www.rfc-editor.org/rfc/rfc7636#section-4.2 + code_challenge = ( + base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()) + .decode() + .replace("=", "") + ) + + server_metadata = await get_server_metadata( + config.Registry[vo].IdP.server_metadata_url + ) + + # Take these two from CS/.well-known + authorization_endpoint = server_metadata["authorization_endpoint"] + + # Encrypt the state and pass it to the IAM + # Needed to retrieve the original flow details when the user is redirected back to the redirect_uri + encrypted_state = encrypt_state( + state | {"vo": vo, "code_verifier": code_verifier}, cipher_suite + ) + + url_params = [ + "response_type=code", + f"code_challenge={code_challenge}", + "code_challenge_method=S256", + f"client_id={config.Registry[vo].IdP.ClientID}", + f"redirect_uri={redirect_uri}", + "scope=openid%20profile", + f"state={encrypted_state}", + ] + authorization_flow_url = f"{authorization_endpoint}?{'&'.join(url_params)}" + return authorization_flow_url + + +async def get_token_from_iam( + config, vo: str, code: str, state: dict[str, str], redirect_uri: str +) -> dict[str, str]: + """Get the token from the IAM using the code and state. Return the ID token.""" + server_metadata = await get_server_metadata( + config.Registry[vo].IdP.server_metadata_url + ) + + # Take these two from CS/.well-known + token_endpoint = server_metadata["token_endpoint"] + + data = { + "grant_type": GrantType.authorization_code.value, + "client_id": config.Registry[vo].IdP.ClientID, + "code": code, + "code_verifier": state["code_verifier"], + "redirect_uri": redirect_uri, + } + + async with httpx.AsyncClient() as c: + res = await c.post( + token_endpoint, + data=data, + ) + if res.status_code >= 500: + raise IAMServerError("Failed to contact IAM server") + elif res.status_code >= 400: + raise IAMClientError("Failed to contact IAM server") + + raw_id_token = res.json()["id_token"] + # Extract the payload and verify it + try: + id_token = await parse_id_token( + config=config, + vo=vo, + raw_id_token=raw_id_token, + ) + except OAuthError: + raise + + return id_token + + +async def verify_dirac_refresh_token( + refresh_token: str, + settings: AuthSettings, +) -> tuple[UUID, float, bool]: + """Verify dirac user token and return a UserInfo class + Used for each API endpoint. 
+ """ + jwt = JsonWebToken(settings.token_algorithm) + token = jwt.decode( + refresh_token, + key=settings.token_key.jwk, + ) + token.validate() + + return ( + UUID(token["jti"], version=4), + float(token["exp"]), + token["legacy_exchange"], + ) + + +def get_allowed_user_properties(config: Config, sub, vo: str) -> set[SecurityProperty]: + """Retrieve all properties of groups a user is registered in.""" + allowed_user_properties = set() + for group in config.Registry[vo].Groups: + if sub in config.Registry[vo].Groups[group].Users: + allowed_user_properties.update(config.Registry[vo].Groups[group].Properties) + return allowed_user_properties + + +def parse_and_validate_scope( + scope: str, config: Config, available_properties: set[SecurityProperty] +) -> ScopeInfoDict: + """Check: + * At most one VO + * At most one group + * group belongs to VO + * properties are known + return dict with group and properties. + + :raises: + * ValueError in case the scope isn't valide + """ + scopes = set(scope.split(" ")) + + groups = [] + properties = [] + vos = [] + unrecognised = [] + for scope in scopes: + if scope.startswith("group:"): + groups.append(scope.split(":", 1)[1]) + elif scope.startswith("property:"): + properties.append(scope.split(":", 1)[1]) + elif scope.startswith("vo:"): + vos.append(scope.split(":", 1)[1]) + else: + unrecognised.append(scope) + if unrecognised: + raise ValueError(f"Unrecognised scopes: {unrecognised}") + + if not vos: + available_vo_scopes = [repr(f"vo:{vo}") for vo in config.Registry] + raise ValueError( + "No vo scope requested, available values: " + f"{' '.join(available_vo_scopes)}" + ) + elif len(vos) > 1: + raise ValueError(f"Only one vo is allowed but got {vos}") + else: + vo = vos[0] + if vo not in config.Registry: + raise ValueError(f"VO {vo} is not known to this installation") + + if not groups: + # TODO: Handle multiple groups correctly + group = config.Registry[vo].DefaultGroup + elif len(groups) > 1: + raise ValueError(f"Only one DIRAC group allowed but got {groups}") + else: + group = groups[0] + if group not in config.Registry[vo].Groups: + raise ValueError(f"{group} not in {vo} groups") + + allowed_properties = config.Registry[vo].Groups[group].Properties + properties.extend([str(p) for p in allowed_properties]) + + if not set(properties).issubset(available_properties): + raise ValueError( + f"{set(properties)-set(available_properties)} are not valid properties" + ) + + return { + "group": group, + "properties": set(sorted(properties)), + "vo": vo, + } diff --git a/diracx-logic/src/diracx/logic/auth/well_known.py b/diracx-logic/src/diracx/logic/auth/well_known.py new file mode 100644 index 00000000..d0f4fe9e --- /dev/null +++ b/diracx-logic/src/diracx/logic/auth/well_known.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from diracx.core.config.schema import Config +from diracx.core.models import GroupInfo, Metadata, OpenIDConfiguration +from diracx.core.settings import AuthSettings + + +async def get_openid_configuration( + token_endpoint: str, + userinfo_endpoint: str, + authorization_endpoint: str, + device_authorization_endpoint: str, + config: Config, + settings: AuthSettings, +) -> OpenIDConfiguration: + """OpenID Connect discovery endpoint.""" + scopes_supported = [] + for vo in config.Registry: + scopes_supported.append(f"vo:{vo}") + scopes_supported += [f"group:{vo}" for vo in config.Registry[vo].Groups] + scopes_supported += [f"property:{p}" for p in settings.available_properties] + + return { + "issuer": settings.token_issuer, + 
"token_endpoint": token_endpoint, + "userinfo_endpoint": userinfo_endpoint, + "authorization_endpoint": authorization_endpoint, + "device_authorization_endpoint": device_authorization_endpoint, + "grant_types_supported": [ + "authorization_code", + "urn:ietf:params:oauth:grant-type:device_code", + ], + "scopes_supported": scopes_supported, + "response_types_supported": ["code"], + "token_endpoint_auth_signing_alg_values_supported": [settings.token_algorithm], + "token_endpoint_auth_methods_supported": ["none"], + "code_challenge_methods_supported": ["S256"], + } + + +async def get_installation_metadata( + config: Config, +) -> Metadata: + """Get metadata about the dirac installation.""" + metadata: Metadata = { + "virtual_organizations": {}, + } + for vo, vo_info in config.Registry.items(): + groups: dict[str, GroupInfo] = { + group: {"properties": sorted(group_info.Properties)} + for group, group_info in vo_info.Groups.items() + } + metadata["virtual_organizations"][vo] = { + "groups": groups, + "support": { + "message": vo_info.Support.Message, + "webpage": vo_info.Support.Webpage, + "email": vo_info.Support.Email, + }, + "default_group": vo_info.DefaultGroup, + } + + return metadata diff --git a/diracx-logic/src/diracx/logic/jobs/__init__.py b/diracx-logic/src/diracx/logic/jobs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/diracx-logic/src/diracx/logic/jobs/query.py b/diracx-logic/src/diracx/logic/jobs/query.py new file mode 100644 index 00000000..e764f264 --- /dev/null +++ b/diracx-logic/src/diracx/logic/jobs/query.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import logging +from typing import Any + +from diracx.core.config.schema import Config +from diracx.core.models import ( + JobSearchParams, + JobSummaryParams, + ScalarSearchOperator, +) +from diracx.db.os.job_parameters import JobParametersDB +from diracx.db.sql.job.db import JobDB +from diracx.db.sql.job_logging.db import JobLoggingDB + +logger = logging.getLogger(__name__) + + +MAX_PER_PAGE = 10000 + + +async def search( + config: Config, + job_db: JobDB, + job_parameters_db: JobParametersDB, + job_logging_db: JobLoggingDB, + preferred_username: str, + page: int = 1, + per_page: int = 100, + body: JobSearchParams | None = None, +) -> tuple[int, list[dict[str, Any]]]: + """Retrieve information about jobs.""" + # Apply a limit to per_page to prevent abuse of the API + if per_page > MAX_PER_PAGE: + per_page = MAX_PER_PAGE + + if body is None: + body = JobSearchParams() + + if query_logging_info := ("LoggingInfo" in (body.parameters or [])): + if body.parameters: + body.parameters.remove("LoggingInfo") + body.parameters = ["JobID"] + (body.parameters or []) + + # TODO: Apply all the job policy stuff properly using user_info + if not config.Operations["Defaults"].Services.JobMonitoring.GlobalJobsInfo: + body.search.append( + { + "parameter": "Owner", + "operator": ScalarSearchOperator.EQUAL, + # TODO-385: https://github.com/DIRACGrid/diracx/issues/385 + # The value shoud be user_info.sub, + # but since we historically rely on the preferred_username + # we will keep using the preferred_username for now. 
+ "value": preferred_username, + } + ) + + total, jobs = await job_db.search( + body.parameters, + body.search, + body.sort, + distinct=body.distinct, + page=page, + per_page=per_page, + ) + + if query_logging_info: + job_logging_info = await job_logging_db.get_records( + [job["JobID"] for job in jobs] + ) + for job in jobs: + job.update({"LoggingInfo": job_logging_info[job["JobID"]]}) + + return total, jobs + + +async def summary( + config: Config, + job_db: JobDB, + preferred_username: str, + body: JobSummaryParams, +): + """Show information suitable for plotting.""" + if not config.Operations["Defaults"].Services.JobMonitoring.GlobalJobsInfo: + body.search.append( + { + "parameter": "Owner", + "operator": ScalarSearchOperator.EQUAL, + # TODO-385: https://github.com/DIRACGrid/diracx/issues/385 + # The value shoud be user_info.sub, + # but since we historically rely on the preferred_username + # we will keep using the preferred_username for now. + "value": preferred_username, + } + ) + return await job_db.summary(body.grouping, body.search) diff --git a/diracx-logic/src/diracx/logic/jobs/sandboxes.py b/diracx-logic/src/diracx/logic/jobs/sandboxes.py new file mode 100644 index 00000000..5a97e107 --- /dev/null +++ b/diracx-logic/src/diracx/logic/jobs/sandboxes.py @@ -0,0 +1,186 @@ +from __future__ import annotations + +from typing import Literal + +from pyparsing import Any + +from diracx.core.exceptions import SandboxAlreadyInsertedError, SandboxNotFoundError +from diracx.core.models import ( + SandboxDownloadResponse, + SandboxInfo, + SandboxType, + SandboxUploadResponse, + UserInfo, +) +from diracx.core.s3 import ( + generate_presigned_upload, + s3_object_exists, +) +from diracx.core.settings import SandboxStoreSettings +from diracx.db.sql.sandbox_metadata.db import SandboxMetadataDB + +MAX_SANDBOX_SIZE_BYTES = 100 * 1024 * 1024 + +SANDBOX_PFN_REGEX = ( + # Starts with /S3/ or /SB:|/S3/ + r"^(:?SB:[A-Za-z]+\|)?/S3/[a-z0-9\.\-]{3,63}" + # Followed ////:. + r"(?:/[^/]+){3}/[a-z0-9]{3,10}:[0-9a-f]{64}\.[a-z0-9\.]+$" +) + + +async def initiate_sandbox_upload( + user_info: UserInfo, + sandbox_info: SandboxInfo, + sandbox_metadata_db: SandboxMetadataDB, + settings: SandboxStoreSettings, +) -> SandboxUploadResponse: + """Get the PFN for the given sandbox, initiate an upload as required. + + If the sandbox already exists in the database then the PFN is returned + and there is no "url" field in the response. + + If the sandbox does not exist in the database then the "url" and "fields" + should be used to upload the sandbox to the storage backend. + """ + pfn = sandbox_metadata_db.get_pfn(settings.bucket_name, user_info, sandbox_info) + + # TODO: THis test should come first, but if we do + # the access policy will crash for not having been called + # so we need to find a way to ackownledge that + + if sandbox_info.size > MAX_SANDBOX_SIZE_BYTES: + raise ValueError( + f"Sandbox too large, maximum allowed is {MAX_SANDBOX_SIZE_BYTES} bytes" + ) + full_pfn = f"SB:{settings.se_name}|{pfn}" + + try: + exists_and_assigned = await sandbox_metadata_db.sandbox_is_assigned( + pfn, settings.se_name + ) + except SandboxNotFoundError: + # The sandbox doesn't exist in the database + pass + else: + # As sandboxes are registered in the DB before uploading to the storage + # backend we can't rely on their existence in the database to determine if + # they have been uploaded. Instead we check if the sandbox has been + # assigned to a job. 
If it has then we know it has been uploaded and we + # can avoid communicating with the storage backend. + if exists_and_assigned or s3_object_exists( + settings.s3_client, settings.bucket_name, pfn_to_key(pfn) + ): + await sandbox_metadata_db.update_sandbox_last_access_time( + settings.se_name, pfn + ) + return SandboxUploadResponse(pfn=full_pfn) + + upload_info = await generate_presigned_upload( + settings.s3_client, + settings.bucket_name, + pfn_to_key(pfn), + sandbox_info.checksum_algorithm, + sandbox_info.checksum, + sandbox_info.size, + settings.url_validity_seconds, + ) + await insert_sandbox( + sandbox_metadata_db, settings.se_name, user_info, pfn, sandbox_info.size + ) + + return SandboxUploadResponse(**upload_info, pfn=full_pfn) + + +async def get_sandbox_file( + pfn: str, + settings: SandboxStoreSettings, +) -> SandboxDownloadResponse: + """Get a presigned URL to download a sandbox file.""" + short_pfn = pfn.split("|", 1)[-1] + + # TODO: Support by name and by job id? + presigned_url = await settings.s3_client.generate_presigned_url( + ClientMethod="get_object", + Params={"Bucket": settings.bucket_name, "Key": pfn_to_key(short_pfn)}, + ExpiresIn=settings.url_validity_seconds, + ) + return SandboxDownloadResponse( + url=presigned_url, expires_in=settings.url_validity_seconds + ) + + +async def get_job_sandboxes( + job_id: int, + sandbox_metadata_db: SandboxMetadataDB, +) -> dict[str, list[Any]]: + """Get input and output sandboxes of given job.""" + input_sb = await sandbox_metadata_db.get_sandbox_assigned_to_job( + job_id, SandboxType.Input + ) + output_sb = await sandbox_metadata_db.get_sandbox_assigned_to_job( + job_id, SandboxType.Output + ) + return {SandboxType.Input: input_sb, SandboxType.Output: output_sb} + + +async def get_job_sandbox( + job_id: int, + sandbox_metadata_db: SandboxMetadataDB, + sandbox_type: Literal["input", "output"], +) -> list[Any]: + """Get input or output sandbox of given job.""" + return await sandbox_metadata_db.get_sandbox_assigned_to_job( + job_id, SandboxType(sandbox_type.capitalize()) + ) + + +async def assign_sandbox_to_job( + job_id: int, + pfn: str, + sandbox_metadata_db: SandboxMetadataDB, + settings: SandboxStoreSettings, +): + """Map the pfn as output sandbox to job.""" + short_pfn = pfn.split("|", 1)[-1] + await sandbox_metadata_db.assign_sandbox_to_jobs( + jobs_ids=[job_id], + pfn=short_pfn, + sb_type=SandboxType.Output, + se_name=settings.se_name, + ) + + +async def unassign_jobs_sandboxes( + jobs_ids: list[int], + sandbox_metadata_db: SandboxMetadataDB, +): + """Delete bulk jobs sandbox mapping.""" + await sandbox_metadata_db.unassign_sandboxes_to_jobs(jobs_ids) + + +def pfn_to_key(pfn: str) -> str: + """Convert a PFN to a key for S3. + + This removes the leading "/S3/" from the PFN. 
+ """ + return "/".join(pfn.split("/")[3:]) + + +async def insert_sandbox( + sandbox_metadata_db: SandboxMetadataDB, + se_name: str, + user: UserInfo, + pfn: str, + size: int, +) -> None: + """Add a new sandbox in SandboxMetadataDB.""" + # TODO: Follow https://github.com/DIRACGrid/diracx/issues/49 + owner_id = await sandbox_metadata_db.get_owner_id(user) + if owner_id is None: + owner_id = await sandbox_metadata_db.insert_owner(user) + + try: + await sandbox_metadata_db.insert_sandbox(owner_id, se_name, pfn, size) + except SandboxAlreadyInsertedError: + await sandbox_metadata_db.update_sandbox_last_access_time(se_name, pfn) diff --git a/diracx-db/src/diracx/db/sql/utils/job.py b/diracx-logic/src/diracx/logic/jobs/status.py similarity index 66% rename from diracx-db/src/diracx/db/sql/utils/job.py rename to diracx-logic/src/diracx/logic/jobs/status.py index 3ffc587a..63b1392a 100644 --- a/diracx-db/src/diracx/db/sql/utils/job.py +++ b/diracx-logic/src/diracx/logic/jobs/status.py @@ -1,17 +1,26 @@ from __future__ import annotations -import asyncio +import logging from collections import defaultdict -from copy import deepcopy from datetime import datetime, timezone from typing import Any from unittest.mock import MagicMock -from fastapi import BackgroundTasks -from pydantic import BaseModel +from DIRAC.Core.Utilities import TimeUtilities +from DIRAC.Core.Utilities.ClassAd.ClassAdLight import ClassAd +from DIRAC.Core.Utilities.ReturnValues import SErrorException, returnValueOrRaise +from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import ( + compressJDL, + extractJDL, +) +from DIRAC.WorkloadManagementSystem.Utilities.JobStatusUtility import ( + getNewStatus, + getStartAndEndTime, +) from diracx.core.config.schema import Config from diracx.core.models import ( + JobLoggingRecord, JobMinorStatus, JobStatus, JobStatusUpdate, @@ -19,137 +28,236 @@ VectorSearchOperator, VectorSearchSpec, ) -from diracx.db.sql.job_logging.db import JobLoggingRecord +from diracx.db.sql.job.db import JobDB +from diracx.db.sql.job_logging.db import JobLoggingDB +from diracx.db.sql.sandbox_metadata.db import SandboxMetadataDB +from diracx.db.sql.task_queue.db import TaskQueueDB +from diracx.logic.jobs.utils import check_and_prepare_job +from diracx.logic.task_queues.priority import recalculate_tq_shares_for_entity -from .. 
import JobDB, JobLoggingDB, SandboxMetadataDB, TaskQueueDB +logger = logging.getLogger(__name__) -class JobSubmissionSpec(BaseModel): - jdl: str - owner: str - owner_group: str - initial_status: str - initial_minor_status: str - vo: str +async def remove_jobs( + job_ids: list[int], + config: Config, + job_db: JobDB, + job_logging_db: JobLoggingDB, + sandbox_metadata_db: SandboxMetadataDB, + task_queue_db: TaskQueueDB, +): + """Fully remove a list of jobs from the WMS databases.""" + # Remove the staging task from the StorageManager + # TODO: this was not done in the JobManagerHandler, but it was done in the kill method + # I think it should be done here too + # TODO: implement StorageManagerClient + # returnValueOrRaise(StorageManagerClient().killTasksBySourceTaskID([job_id])) + # TODO: this was also not done in the JobManagerHandler, but it was done in the JobCleaningAgent + # I think it should be done here as well + await sandbox_metadata_db.unassign_sandboxes_to_jobs(job_ids) -async def submit_jobs_jdl(jobs: list[JobSubmissionSpec], job_db: JobDB): - from DIRAC.Core.Utilities.ClassAd.ClassAdLight import ClassAd - from DIRAC.Core.Utilities.ReturnValues import returnValueOrRaise - from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import ( - checkAndAddOwner, - createJDLWithInitialStatus, - ) + # Remove the job from TaskQueueDB + await remove_jobs_from_task_queue(job_ids, config, task_queue_db) - jobs_to_insert = {} - jdls_to_update = {} - inputdata_to_insert = {} - original_jdls = [] - - # generate the jobIDs first - # TODO: should ForgivingTaskGroup be used? - async with asyncio.TaskGroup() as tg: - for job in jobs: - original_jdl = deepcopy(job.jdl) - job_manifest = returnValueOrRaise( - checkAndAddOwner(original_jdl, job.owner, job.owner_group) - ) + # Remove the job from JobLoggingDB + await job_logging_db.delete_records(job_ids) - # Fix possible lack of brackets - if original_jdl.strip()[0] != "[": - original_jdl = f"[{original_jdl}]" + # Remove the job from JobDB + await job_db.delete_jobs(job_ids) - original_jdls.append( - ( - original_jdl, - job_manifest, - tg.create_task(job_db.create_job(original_jdl)), + +async def set_job_statuses( + status_changes: dict[int, dict[datetime, JobStatusUpdate]], + config: Config, + job_db: JobDB, + job_logging_db: JobLoggingDB, + task_queue_db: TaskQueueDB, + force: bool = False, + additional_attributes: dict[int, dict[str, str]] = {}, +) -> SetJobStatusReturn: + """Set various status fields for job specified by its jobId. + Set only the last status in the JobDB, updating all the status + logging information in the JobLoggingDB. The status dict has datetime + as a key and status information dictionary as values. 
+ + :raises: JobNotFound if the job is not found in one of the DBs + """ + # check that the datetime contains timezone info + for job_id, status in status_changes.items(): + for dt in status: + if dt.tzinfo is None: + raise ValueError( + f"Timestamp {dt} is not timezone aware for job {job_id}" ) - ) - async with asyncio.TaskGroup() as tg: - for job, (original_jdl, job_manifest_, job_id_task) in zip(jobs, original_jdls): - job_id = job_id_task.result() - job_attrs = { - "JobID": job_id, - "LastUpdateTime": datetime.now(tz=timezone.utc), - "SubmissionTime": datetime.now(tz=timezone.utc), - "Owner": job.owner, - "OwnerGroup": job.owner_group, - "VO": job.vo, + failed: dict[int, Any] = {} + deletable_killable_jobs = set() + job_attribute_updates: dict[int, dict[str, str]] = {} + job_logging_updates: list[JobLoggingRecord] = [] + status_dicts: dict[int, dict[datetime, dict[str, str]]] = defaultdict(dict) + + # transform JobStateUpdate objects into dicts + status_dicts = { + job_id: { + key: {k: v for k, v in value.model_dump().items() if v is not None} + for key, value in status.items() + } + for job_id, status in status_changes.items() + } + + # search all jobs at once + _, results = await job_db.search( + parameters=["Status", "StartExecTime", "EndExecTime", "JobID"], + search=[ + { + "parameter": "JobID", + "operator": VectorSearchOperator.IN, + "values": list(set(status_changes.keys())), } + ], + sorts=[], + ) + if not results: + return SetJobStatusReturn( + success={}, + failed={ + int(job_id): {"detail": "Not found"} for job_id in status_changes.keys() + }, + ) - job_manifest_.setOption("JobID", job_id) + found_jobs = set(int(res["JobID"]) for res in results) + failed.update( + { + int(nf_job_id): {"detail": "Not found"} + for nf_job_id in set(status_changes.keys()) - found_jobs + } + ) + # Get the latest time stamps of major status updates + wms_time_stamps = await job_logging_db.get_wms_time_stamps(found_jobs) - # 2.- Check JDL and Prepare DIRAC JDL - job_jdl = job_manifest_.dumpAsJDL() + for res in results: + job_id = int(res["JobID"]) + current_status = res["Status"] + start_time = res["StartExecTime"] + end_time = res["EndExecTime"] - # Replace the JobID placeholder if any - if job_jdl.find("%j") != -1: - job_jdl = job_jdl.replace("%j", str(job_id)) + # If the current status is Stalled and we get an update, it should probably be "Running" + if current_status == JobStatus.STALLED: + current_status = JobStatus.RUNNING - class_ad_job = ClassAd(job_jdl) + ##################################################################################################### + status_dict = status_dicts[job_id] + # This is more precise than "LastTime". time_stamps is a sorted list of tuples... + time_stamps = sorted((float(t), s) for s, t in wms_time_stamps[job_id].items()) + last_time = TimeUtilities.fromEpoch(time_stamps[-1][0]).replace( + tzinfo=timezone.utc + ) - class_ad_req = ClassAd("[]") - if not class_ad_job.isOK(): - # Rollback the entire transaction - raise ValueError(f"Error in JDL syntax for job JDL: {original_jdl}") - # TODO: check if that is actually true - if class_ad_job.lookupAttribute("Parameters"): - raise NotImplementedError("Parameters in the JDL are not supported") + # Get chronological order of new updates + update_times = sorted(status_dict) - # TODO is this even needed? 
- class_ad_job.insertAttributeInt("JobID", job_id) + new_start_time, new_end_time = getStartAndEndTime( + start_time, end_time, update_times, time_stamps, status_dict + ) - await job_db.check_and_prepare_job( - job_id, - class_ad_job, - class_ad_req, - job.owner, - job.owner_group, - job_attrs, - job.vo, + job_data: dict[str, str] = {} + new_status: str | None = None + if update_times[-1] >= last_time: + new_status, new_minor, new_application = ( + returnValueOrRaise( # TODO: Catch this + getNewStatus( + job_id, + update_times, + last_time, + status_dict, + current_status, + force, + MagicMock(), # FIXME + ) + ) ) - job_jdl = createJDLWithInitialStatus( - class_ad_job, - class_ad_req, - job_db.jdl_2_db_parameters, - job_attrs, - job.initial_status, - job.initial_minor_status, - modern=True, + + if new_status: + job_data.update(additional_attributes.get(job_id, {})) + job_data["Status"] = new_status + job_data["LastUpdateTime"] = str(datetime.now(timezone.utc)) + if new_minor: + job_data["MinorStatus"] = new_minor + if new_application: + job_data["ApplicationStatus"] = new_application + + # TODO: implement elasticJobParametersDB ? + # if cls.elasticJobParametersDB: + # result = cls.elasticJobParametersDB.setJobParameter(int(jobID), "Status", status) + # if not result["OK"]: + # return result + + for upd_time in update_times: + if status_dict[upd_time]["Source"].startswith("Job"): + job_data["HeartBeatTime"] = str(upd_time) + + if not start_time and new_start_time: + job_data["StartExecTime"] = new_start_time + + if not end_time and new_end_time: + job_data["EndExecTime"] = new_end_time + + ##################################################################################################### + # delete or kill job, if we transition to DELETED or KILLED state + if new_status in [JobStatus.DELETED, JobStatus.KILLED]: + deletable_killable_jobs.add(job_id) + + # Update database tables + if job_data: + job_attribute_updates[job_id] = job_data + + for upd_time in update_times: + s_dict = status_dict[upd_time] + job_logging_updates.append( + JobLoggingRecord( + job_id=job_id, + status=s_dict.get("Status", "idem"), + minor_status=s_dict.get("MinorStatus", "idem"), + application_status=s_dict.get("ApplicationStatus", "idem"), + date=upd_time, + source=s_dict.get("Source", "Unknown"), + ) ) - jobs_to_insert[job_id] = job_attrs - jdls_to_update[job_id] = job_jdl + await job_db.set_job_attributes(job_attribute_updates) - if class_ad_job.lookupAttribute("InputData"): - input_data = class_ad_job.getListFromExpression("InputData") - inputdata_to_insert[job_id] = [lfn for lfn in input_data if lfn] + await remove_jobs_from_task_queue( + list(deletable_killable_jobs), + config, + task_queue_db, + ) - tg.create_task(job_db.update_job_jdls(jdls_to_update)) - tg.create_task(job_db.insert_job_attributes(jobs_to_insert)) + # TODO: implement StorageManagerClient + # returnValueOrRaise(StorageManagerClient().killTasksBySourceTaskID(job_ids)) + + if deletable_killable_jobs: + await job_db.set_job_commands( + [(job_id, "Kill", "") for job_id in deletable_killable_jobs] + ) - if inputdata_to_insert: - tg.create_task(job_db.insert_input_data(inputdata_to_insert)) + await job_logging_db.insert_records(job_logging_updates) - return list(jobs_to_insert.keys()) + return SetJobStatusReturn( + success=job_attribute_updates, + failed=failed, + ) -async def reschedule_jobs_bulk( +async def reschedule_jobs( job_ids: list[int], config: Config, job_db: JobDB, job_logging_db: JobLoggingDB, task_queue_db: TaskQueueDB, - 
background_task: BackgroundTasks, - *, - reset_counter=False, -) -> dict[str, Any]: + reset_jobs: bool = False, +): """Reschedule given job.""" - from DIRAC.Core.Utilities.ClassAd.ClassAdLight import ClassAd - from DIRAC.Core.Utilities.ReturnValues import SErrorException - failed = {} reschedule_max = config.Operations[ "Defaults" @@ -200,7 +308,7 @@ async def reschedule_jobs_bulk( # Noop continue - if reset_counter: + if reset_jobs: job_attrs["RescheduleCounter"] = 0 else: job_attrs["RescheduleCounter"] = int(job_attrs["RescheduleCounter"]) + 1 @@ -216,7 +324,6 @@ async def reschedule_jobs_bulk( failed[job_id] = { "detail": f"Maximum number of reschedules exceeded ({reschedule_max})" } - # DATABASE OPERATION (status change) continue jobs_to_resched[job_id] = job_attrs @@ -235,7 +342,7 @@ async def reschedule_jobs_bulk( # await self.delete_job_parameters(job_id) # await self.delete_job_optimizer_parameters(job_id) - def parse_jdl(job_id, job_jdl): + def parse_jdl(job_id: int, job_jdl: str): if not job_jdl.strip().startswith("["): job_jdl = f"[{job_jdl}]" class_ad_job = ClassAd(job_jdl) @@ -243,7 +350,7 @@ def parse_jdl(job_id, job_jdl): return class_ad_job job_jdls = { - jobid: parse_jdl(jobid, jdl) + jobid: parse_jdl(jobid, extractJDL(jdl)) for jobid, jdl in ( (await job_db.get_job_jdls(surviving_job_ids, original=True)).items() ) @@ -253,7 +360,7 @@ def parse_jdl(job_id, job_jdl): class_ad_job = job_jdls[job_id] class_ad_req = ClassAd("[]") try: - await job_db.check_and_prepare_job( + await check_and_prepare_job( job_id, class_ad_job, class_ad_req, @@ -261,6 +368,7 @@ def parse_jdl(job_id, job_jdl): jobs_to_resched[job_id]["OwnerGroup"], {"RescheduleCounter": jobs_to_resched[job_id]["RescheduleCounter"]}, class_ad_job.getAttributeString("VirtualOrganization"), + job_db, ) except SErrorException as e: failed[job_id] = {"detail": str(e)} @@ -293,7 +401,7 @@ def parse_jdl(job_id, job_jdl): } # set new JDL - jdl_changes[job_id] = job_jdl + jdl_changes[job_id] = compressJDL(job_jdl) # set new status status_changes[job_id] = { @@ -307,33 +415,35 @@ def parse_jdl(job_id, job_jdl): attribute_changes[job_id].update(additional_attrs) if surviving_job_ids: - # BULK STATUS UPDATE - # DATABASE OPERATION - set_job_status_result = await set_job_status_bulk( - status_changes, - config, - job_db, - job_logging_db, - task_queue_db, - background_task, + set_job_status_result = await set_job_statuses( + status_changes=status_changes, + config=config, + job_db=job_db, + job_logging_db=job_logging_db, + task_queue_db=task_queue_db, additional_attributes=attribute_changes, ) - # BULK JDL UPDATE - # DATABASE OPERATION - await job_db.set_job_jdl_bulk(jdl_changes) + await job_db.update_job_jdls(jdl_changes) + + success = {} + for job_id, set_status_result in set_job_status_result.success.items(): + if job_id in failed: + continue + + jdl = job_jdls.get(job_id, None) + if jdl: + jdl = jdl.asJDL() + + success[job_id] = { + "InputData": jdl, + **attribute_changes[job_id], + **set_status_result.model_dump(), + } return { "failed": failed, - "success": { - job_id: { - "InputData": job_jdls.get(job_id, None), - **attribute_changes[job_id], - **set_status_result.model_dump(), - } - for job_id, set_status_result in set_job_status_result.success.items() - if job_id not in failed - }, + "success": success, } return { @@ -342,237 +452,25 @@ def parse_jdl(job_id, job_jdl): } -async def set_job_status_bulk( - status_changes: dict[int, dict[datetime, JobStatusUpdate]], - config: Config, - job_db: JobDB, - job_logging_db: 
JobLoggingDB, - task_queue_db: TaskQueueDB, - background_task: BackgroundTasks, - *, - force: bool = False, - additional_attributes: dict[int, dict[str, str]] = {}, -) -> SetJobStatusReturn: - """Set various status fields for job specified by its jobId. - Set only the last status in the JobDB, updating all the status - logging information in the JobLoggingDB. The status dict has datetime - as a key and status information dictionary as values. - - :raises: JobNotFound if the job is not found in one of the DBs - """ - from DIRAC.Core.Utilities import TimeUtilities - from DIRAC.Core.Utilities.ReturnValues import returnValueOrRaise - from DIRAC.WorkloadManagementSystem.Utilities.JobStatusUtility import ( - getNewStatus, - getStartAndEndTime, - ) - - failed: dict[int, Any] = {} - deletable_killable_jobs = set() - job_attribute_updates: dict[int, dict[str, str]] = {} - job_logging_updates: list[JobLoggingRecord] = [] - status_dicts: dict[int, dict[datetime, dict[str, str]]] = defaultdict(dict) - - # transform JobStateUpdate objects into dicts - status_dicts = { - job_id: { - key: {k: v for k, v in value.model_dump().items() if v is not None} - for key, value in status.items() - } - for job_id, status in status_changes.items() - } - - # search all jobs at once - _, results = await job_db.search( - parameters=["Status", "StartExecTime", "EndExecTime", "JobID"], - search=[ - { - "parameter": "JobID", - "operator": VectorSearchOperator.IN, - "values": list(set(status_changes.keys())), - } - ], - sorts=[], - ) - if not results: - return SetJobStatusReturn( - success={}, - failed={ - int(job_id): {"detail": "Not found"} for job_id in status_changes.keys() - }, - ) - - found_jobs = set(int(res["JobID"]) for res in results) - failed.update( - { - int(nf_job_id): {"detail": "Not found"} - for nf_job_id in set(status_changes.keys()) - found_jobs - } - ) - # Get the latest time stamps of major status updates - wms_time_stamps = await job_logging_db.get_wms_time_stamps_bulk(found_jobs) - - for res in results: - job_id = int(res["JobID"]) - current_status = res["Status"] - start_time = res["StartExecTime"] - end_time = res["EndExecTime"] - - # If the current status is Stalled and we get an update, it should probably be "Running" - if current_status == JobStatus.STALLED: - current_status = JobStatus.RUNNING - - ##################################################################################################### - status_dict = status_dicts[job_id] - # This is more precise than "LastTime". time_stamps is a sorted list of tuples... 
- time_stamps = sorted((float(t), s) for s, t in wms_time_stamps[job_id].items()) - last_time = TimeUtilities.fromEpoch(time_stamps[-1][0]).replace( - tzinfo=timezone.utc - ) - - # Get chronological order of new updates - update_times = sorted(status_dict) - - new_start_time, new_end_time = getStartAndEndTime( - start_time, end_time, update_times, time_stamps, status_dict - ) - - job_data: dict[str, str] = {} - new_status: str | None = None - if update_times[-1] >= last_time: - new_status, new_minor, new_application = ( - returnValueOrRaise( # TODO: Catch this - getNewStatus( - job_id, - update_times, - last_time, - status_dict, - current_status, - force, - MagicMock(), # FIXME - ) - ) - ) - - if new_status: - job_data.update(additional_attributes.get(job_id, {})) - job_data["Status"] = new_status - job_data["LastUpdateTime"] = str(datetime.now(timezone.utc)) - if new_minor: - job_data["MinorStatus"] = new_minor - if new_application: - job_data["ApplicationStatus"] = new_application - - # TODO: implement elasticJobParametersDB ? - # if cls.elasticJobParametersDB: - # result = cls.elasticJobParametersDB.setJobParameter(int(jobID), "Status", status) - # if not result["OK"]: - # return result - - for upd_time in update_times: - if status_dict[upd_time]["Source"].startswith("Job"): - job_data["HeartBeatTime"] = str(upd_time) - - if not start_time and new_start_time: - job_data["StartExecTime"] = new_start_time - - if not end_time and new_end_time: - job_data["EndExecTime"] = new_end_time - - ##################################################################################################### - # delete or kill job, if we transition to DELETED or KILLED state - if new_status in [JobStatus.DELETED, JobStatus.KILLED]: - deletable_killable_jobs.add(job_id) - - # Update database tables - if job_data: - job_attribute_updates[job_id] = job_data - - for upd_time in update_times: - s_dict = status_dict[upd_time] - job_logging_updates.append( - JobLoggingRecord( - job_id=job_id, - status=s_dict.get("Status", "idem"), - minor_status=s_dict.get("MinorStatus", "idem"), - application_status=s_dict.get("ApplicationStatus", "idem"), - date=upd_time, - source=s_dict.get("Source", "Unknown"), - ) - ) - - await job_db.set_job_attributes_bulk(job_attribute_updates) - - await remove_jobs_from_task_queue( - list(deletable_killable_jobs), config, task_queue_db, background_task - ) - - # TODO: implement StorageManagerClient - # returnValueOrRaise(StorageManagerClient().killTasksBySourceTaskID(job_ids)) - - if deletable_killable_jobs: - await job_db.set_job_command_bulk( - [(job_id, "Kill", "") for job_id in deletable_killable_jobs] - ) - - await job_logging_db.bulk_insert_record(job_logging_updates) - - return SetJobStatusReturn( - success=job_attribute_updates, - failed=failed, - ) - - -async def remove_jobs( - job_ids: list[int], - config: Config, - job_db: JobDB, - job_logging_db: JobLoggingDB, - sandbox_metadata_db: SandboxMetadataDB, - task_queue_db: TaskQueueDB, - background_task: BackgroundTasks, -): - """Fully remove a job from the WMS databases. - :raises: nothing. 
- """ - # Remove the staging task from the StorageManager - # TODO: this was not done in the JobManagerHandler, but it was done in the kill method - # I think it should be done here too - # TODO: implement StorageManagerClient - # returnValueOrRaise(StorageManagerClient().killTasksBySourceTaskID([job_id])) - - # TODO: this was also not done in the JobManagerHandler, but it was done in the JobCleaningAgent - # I think it should be done here as well - await sandbox_metadata_db.unassign_sandboxes_to_jobs(job_ids) - - # Remove the job from TaskQueueDB - await remove_jobs_from_task_queue(job_ids, config, task_queue_db, background_task) - - # Remove the job from JobLoggingDB - await job_logging_db.delete_records(job_ids) - - # Remove the job from JobDB - await job_db.delete_jobs(job_ids) - - async def remove_jobs_from_task_queue( job_ids: list[int], config: Config, task_queue_db: TaskQueueDB, - background_task: BackgroundTasks, ): """Remove the job from TaskQueueDB.""" - tq_infos = await task_queue_db.get_tq_infos_for_jobs(job_ids) await task_queue_db.remove_jobs(job_ids) + + tq_infos = await task_queue_db.get_tq_infos_for_jobs(job_ids) for tq_id, owner, owner_group, vo in tq_infos: # TODO: move to Celery - background_task.add_task( - task_queue_db.delete_task_queue_if_empty, - tq_id, - owner, - owner_group, - config.Registry[vo].Groups[owner_group].JobShare, - config.Registry[vo].Groups[owner_group].Properties, - config.Operations[vo].Services.JobScheduling.EnableSharesCorrection, - config.Registry[vo].Groups[owner_group].AllowBackgroundTQs, + + # If the task queue is not empty, do not remove it + if not task_queue_db.is_task_queue_empty(tq_id): + continue + + await task_queue_db.delete_task_queue(tq_id) + + # Recalculate shares for the owner group + await recalculate_tq_shares_for_entity( + owner, owner_group, vo, config, task_queue_db ) diff --git a/diracx-logic/src/diracx/logic/jobs/submission.py b/diracx-logic/src/diracx/logic/jobs/submission.py new file mode 100644 index 00000000..6e592b46 --- /dev/null +++ b/diracx-logic/src/diracx/logic/jobs/submission.py @@ -0,0 +1,257 @@ +from __future__ import annotations + +import asyncio +import logging +from copy import deepcopy +from datetime import datetime, timezone + +from DIRAC.Core.Utilities.ClassAd.ClassAdLight import ClassAd +from DIRAC.Core.Utilities.ReturnValues import returnValueOrRaise +from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import ( + checkAndAddOwner, + compressJDL, + createJDLWithInitialStatus, +) +from DIRAC.WorkloadManagementSystem.Utilities.ParametricJob import ( + generateParametricJobs, + getParameterVectorLength, +) +from pydantic import BaseModel + +from diracx.core.models import ( + InsertedJob, + JobLoggingRecord, + JobStatus, + UserInfo, +) +from diracx.db.sql.job.db import JobDB +from diracx.db.sql.job_logging.db import JobLoggingDB +from diracx.logic.jobs.utils import check_and_prepare_job + +logger = logging.getLogger(__name__) + + +class JobSubmissionSpec(BaseModel): + jdl: str + owner: str + owner_group: str + initial_status: str + initial_minor_status: str + vo: str + + +MAX_PARAMETRIC_JOBS = 20 + + +async def submit_jdl_jobs( + job_definitions: list[str], + job_db: JobDB, + job_logging_db: JobLoggingDB, + user_info: UserInfo, +) -> list[InsertedJob]: + """Submit a list of JDLs to the JobDB.""" + # TODO: that needs to go in the legacy adapter (Does it ? 
Because bulk submission is not supported there) + for i in range(len(job_definitions)): + job_definition = job_definitions[i].strip() + if not (job_definition.startswith("[") and job_definition.endswith("]")): + job_definition = f"[{job_definition}]" + job_definitions[i] = job_definition + + if len(job_definitions) == 1: + # Check if the job is a parametric one + job_class_ad = ClassAd(job_definitions[0]) + result = getParameterVectorLength(job_class_ad) + if not result["OK"]: + # FIXME dont do this + print("Issue with getParameterVectorLength", result["Message"]) + return result + n_jobs = result["Value"] + parametric_job = False + if n_jobs is not None and n_jobs > 0: + # if we are here, then jobDesc was the description of a parametric job. So we start unpacking + parametric_job = True + result = generateParametricJobs(job_class_ad) + if not result["OK"]: + # FIXME why? + return result + job_desc_list = result["Value"] + else: + # if we are here, then jobDesc was the description of a single job. + job_desc_list = job_definitions + else: + # if we are here, then jobDesc is a list of JDLs + # we need to check that none of them is a parametric + for job_definition in job_definitions: + res = getParameterVectorLength(ClassAd(job_definition)) + if not res["OK"]: + raise ValueError(res["Message"]) + + if res["Value"]: + raise ValueError("You cannot submit parametric jobs in a bulk fashion") + + job_desc_list = job_definitions + # parametric_job = True + parametric_job = False + + # TODO: make the max number of jobs configurable in the CS + if len(job_desc_list) > MAX_PARAMETRIC_JOBS: + raise ValueError( + f"Normal user cannot submit more than {MAX_PARAMETRIC_JOBS} jobs at once" + ) + + result = [] + + if parametric_job: + initial_status = JobStatus.SUBMITTING + initial_minor_status = "Bulk transaction confirmation" + else: + initial_status = JobStatus.RECEIVED + initial_minor_status = "Job accepted" + + try: + submitted_job_ids = await create_jdl_jobs( + [ + JobSubmissionSpec( + jdl=jdl, + owner=user_info.preferred_username, + owner_group=user_info.dirac_group, + initial_status=initial_status, + initial_minor_status=initial_minor_status, + vo=user_info.vo, + ) + for jdl in job_desc_list + ], + job_db=job_db, + ) + except ExceptionGroup as e: + raise ValueError("JDL syntax error") from e + + logging.debug( + f'Jobs added to the JobDB", "{submitted_job_ids} for {user_info.preferred_username}/{user_info.dirac_group}' + ) + + job_created_time = datetime.now(timezone.utc) + await job_logging_db.insert_records( + [ + JobLoggingRecord( + job_id=int(job_id), + status=initial_status, + minor_status=initial_minor_status, + application_status="Unknown", + date=job_created_time, + source="JobManager", + ) + for job_id in submitted_job_ids + ] + ) + + # if not parametric_job: + # self.__sendJobsToOptimizationMind(submitted_job_ids) + + return [ + InsertedJob( + JobID=job_id, + Status=initial_status, + MinorStatus=initial_minor_status, + TimeStamp=job_created_time, + ) + for job_id in submitted_job_ids + ] + + +async def create_jdl_jobs(jobs: list[JobSubmissionSpec], job_db: JobDB): + """Create jobs from JDLs and insert them into the DB.""" + jobs_to_insert = {} + jdls_to_update = {} + inputdata_to_insert = {} + original_jdls = [] + + # generate the jobIDs first + # TODO: should ForgivingTaskGroup be used? 
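Purely as an illustration of how a caller might drive this entry point (this sketch is not part of the patch; the `sub` value and the DB handles are placeholders, while the other `UserInfo` fields mirror the attributes read above):

from diracx.core.models import UserInfo

async def submit_hello_world(job_db, job_logging_db):
    user = UserInfo(
        sub="some-subject",          # hypothetical subject
        preferred_username="alice",
        dirac_group="my_group",
        vo="myvo",
    )
    jdl = 'Executable = "echo"; Arguments = "hello world";'
    inserted = await submit_jdl_jobs([jdl], job_db, job_logging_db, user)
    # Each InsertedJob carries JobID, Status, MinorStatus and TimeStamp
    return [job.JobID for job in inserted]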
+ async with asyncio.TaskGroup() as tg: + for job in jobs: + original_jdl = deepcopy(job.jdl) + job_manifest = returnValueOrRaise( + checkAndAddOwner(original_jdl, job.owner, job.owner_group) + ) + + # Fix possible lack of brackets + if original_jdl.strip()[0] != "[": + original_jdl = f"[{original_jdl}]" + + original_jdls.append( + ( + original_jdl, + job_manifest, + tg.create_task(job_db.create_job(compressJDL(original_jdl))), + ) + ) + + async with asyncio.TaskGroup() as tg: + for job, (original_jdl, job_manifest_, job_id_task) in zip(jobs, original_jdls): + job_id = job_id_task.result() + job_attrs = { + "JobID": job_id, + "LastUpdateTime": datetime.now(tz=timezone.utc), + "SubmissionTime": datetime.now(tz=timezone.utc), + "Owner": job.owner, + "OwnerGroup": job.owner_group, + "VO": job.vo, + } + + job_manifest_.setOption("JobID", job_id) + + # 2.- Check JDL and Prepare DIRAC JDL + job_jdl = job_manifest_.dumpAsJDL() + + # Replace the JobID placeholder if any + if job_jdl.find("%j") != -1: + job_jdl = job_jdl.replace("%j", str(job_id)) + + class_ad_job = ClassAd(job_jdl) + + class_ad_req = ClassAd("[]") + if not class_ad_job.isOK(): + # Rollback the entire transaction + raise ValueError(f"Error in JDL syntax for job JDL: {original_jdl}") + # TODO: check if that is actually true + if class_ad_job.lookupAttribute("Parameters"): + raise NotImplementedError("Parameters in the JDL are not supported") + + # TODO is this even needed? + class_ad_job.insertAttributeInt("JobID", job_id) + + await check_and_prepare_job( + job_id, + class_ad_job, + class_ad_req, + job.owner, + job.owner_group, + job_attrs, + job.vo, + job_db, + ) + job_jdl = createJDLWithInitialStatus( + class_ad_job, + class_ad_req, + job_db.jdl_2_db_parameters, + job_attrs, + job.initial_status, + job.initial_minor_status, + modern=True, + ) + + jobs_to_insert[job_id] = job_attrs + jdls_to_update[job_id] = compressJDL(job_jdl) + + if class_ad_job.lookupAttribute("InputData"): + input_data = class_ad_job.getListFromExpression("InputData") + inputdata_to_insert[job_id] = [lfn for lfn in input_data if lfn] + + tg.create_task(job_db.update_job_jdls(jdls_to_update)) + tg.create_task(job_db.insert_job_attributes(jobs_to_insert)) + + if inputdata_to_insert: + tg.create_task(job_db.insert_input_data(inputdata_to_insert)) + + return list(jobs_to_insert.keys()) diff --git a/diracx-logic/src/diracx/logic/jobs/utils.py b/diracx-logic/src/diracx/logic/jobs/utils.py new file mode 100644 index 00000000..e4a7ae5c --- /dev/null +++ b/diracx-logic/src/diracx/logic/jobs/utils.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from DIRAC.Core.Utilities.ClassAd.ClassAdLight import ClassAd +from DIRAC.Core.Utilities.DErrno import EWMSSUBM, cmpError +from DIRAC.Core.Utilities.ReturnValues import returnValueOrRaise +from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import ( + checkAndPrepareJob, +) + +from diracx.db.sql.job.db import JobDB + + +async def check_and_prepare_job( + job_id: int, + class_ad_job: ClassAd, + class_ad_req: ClassAd, + owner: str, + owner_group: str, + job_attrs: dict, + vo: str, + job_db: JobDB, +): + """Check Consistency of Submitted JDL and set some defaults + Prepare subJDL with Job Requirements. 
+ """ + ret_val = checkAndPrepareJob( + job_id, + class_ad_job, + class_ad_req, + owner, + owner_group, + job_attrs, + vo, + ) + + if not ret_val["OK"]: + if cmpError(ret_val, EWMSSUBM): + await job_db.set_job_attributes({job_id: job_attrs}) + + returnValueOrRaise(ret_val) diff --git a/diracx-logic/src/diracx/logic/py.typed b/diracx-logic/src/diracx/logic/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/diracx-logic/src/diracx/logic/task_queues/__init__.py b/diracx-logic/src/diracx/logic/task_queues/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/diracx-logic/src/diracx/logic/task_queues/priority.py b/diracx-logic/src/diracx/logic/task_queues/priority.py new file mode 100644 index 00000000..352025df --- /dev/null +++ b/diracx-logic/src/diracx/logic/task_queues/priority.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +from collections import defaultdict +from typing import Any + +from diracx.core.config.schema import Config +from diracx.core.properties import JOB_SHARING +from diracx.db.sql.task_queue.db import TaskQueueDB + +TQ_MIN_SHARE = 0.001 +PRIORITY_IGNORED_FIELDS = ("Sites", "BannedSites") + + +async def recalculate_tq_shares_for_entity( + owner: str, + owner_group: str, + vo: str, + config: Config, + task_queue_db: TaskQueueDB, +): + """Recalculate the shares for a user/userGroup combo.""" + group_properties = config.Registry[vo].Groups[owner_group].Properties + job_share = config.Registry[vo].Groups[owner_group].JobShare + allow_background_tqs = config.Registry[vo].Groups[owner_group].AllowBackgroundTQs + if JOB_SHARING in group_properties: + # If group has JobSharing just set prio for that entry, user is irrelevant + await set_priorities_for_entity( + owner=None, + owner_group=owner_group, + job_share=job_share, + allow_background_tqs=allow_background_tqs, + task_queue_db=task_queue_db, + ) + return + + # Get all owners from the owner group + owners = await task_queue_db.get_task_queue_owners_by_group(owner_group) + num_owners = len(owners) + # If there are no owners do now + if num_owners == 0: + return + + # Split the share amongst the number of owners + entities_shares = {owner: job_share / num_owners for owner, _ in owners.items()} + + # TODO: implement the following + # If corrector is enabled let it work it's magic + # if enable_shares_correction: + # entities_shares = await self.__shares_corrector.correct_shares( + # entitiesShares, group=group + # ) + + # If the user is already known and has more than 1 tq, the rest of the users don't need to be modified + # (The number of owners didn't change) + if owner in owners and owners[owner] > 1: + await set_priorities_for_entity( + owner=owner, + owner_group=owner_group, + job_share=entities_shares[owner], + allow_background_tqs=allow_background_tqs, + task_queue_db=task_queue_db, + ) + return + + # Oops the number of owners may have changed so we recalculate the prio for all owners in the group + for owner in owners: + await set_priorities_for_entity( + owner=owner, + owner_group=owner_group, + job_share=entities_shares[owner], + allow_background_tqs=allow_background_tqs, + task_queue_db=task_queue_db, + ) + + +async def set_priorities_for_entity( + owner_group: str, + job_share: float, + allow_background_tqs: bool, + task_queue_db: TaskQueueDB, + owner: str | None = None, +): + """Set the priority for a user/userGroup combo given a splitted share.""" + tq_dict = await task_queue_db.get_task_queue_priorities(owner_group, owner) + if not tq_dict: + return + + rows = await 
task_queue_db.retrieve_task_queues(list(tq_dict)) + prio_dict = await calculate_priority(tq_dict, rows, job_share, allow_background_tqs) + for prio, tqs in prio_dict.items(): + await task_queue_db.set_priorities_for_entity(tqs, prio) + + +async def calculate_priority( + tq_dict: dict[int, float], + all_tqs_data: dict[int, dict[str, Any]], + share: float, + allow_bg_tqs: bool, +) -> dict[float, list[int]]: + """Calculate the priority for each TQ given a share. + + :param tq_dict: dict of {tq_id: prio} + :param all_tqs_data: dict of {tq_id: {tq_data}}, where tq_data is a dict of {field: value} + :param share: share to be distributed among TQs + :param allow_bg_tqs: allow background TQs to be used + :return: dict of {priority: [tq_ids]} + """ + + def is_background(tq_priority: float, allow_bg_tqs: bool) -> bool: + """A TQ is background if its priority is below a threshold and background TQs are allowed.""" + return tq_priority <= 0.1 and allow_bg_tqs + + # Calculate Sum of priorities of non background TQs + total_prio = sum( + [prio for prio in tq_dict.values() if not is_background(prio, allow_bg_tqs)] + ) + + # Update prio for each TQ + for tq_id, tq_priority in tq_dict.items(): + if is_background(tq_priority, allow_bg_tqs): + prio = TQ_MIN_SHARE + else: + prio = max((share / total_prio) * tq_priority, TQ_MIN_SHARE) + tq_dict[tq_id] = prio + + # Generate groups of TQs that will have the same prio=sum(prios) maomenos + tq_groups: dict[str, list[int]] = defaultdict(list) + for tq_id, tq_data in all_tqs_data.items(): + for field in ("Jobs", "Priority") + PRIORITY_IGNORED_FIELDS: + if field in tq_data: + tq_data.pop(field) + tq_hash = [] + for f in sorted(tq_data): + tq_hash.append(f"{f}:{tq_data[f]}") + tq_hash = "|".join(tq_hash) + # if tq_hash not in tq_groups: + # tq_groups[tq_hash] = [] + tq_groups[tq_hash].append(tq_id) + + # Do the grouping + for tq_group in tq_groups.values(): + total_prio = sum(tq_dict[tq_id] for tq_id in tq_group) + for tq_id in tq_group: + tq_dict[tq_id] = total_prio + + # Group by priorities + result: dict[float, list[int]] = defaultdict(list) + for tq_id, tq_priority in tq_dict.items(): + result[tq_priority].append(tq_id) + + return result diff --git a/diracx-routers/pyproject.toml b/diracx-routers/pyproject.toml index d99a51cc..b02bdf31 100644 --- a/diracx-routers/pyproject.toml +++ b/diracx-routers/pyproject.toml @@ -13,12 +13,10 @@ classifiers = [ "Topic :: System :: Distributed Computing", ] dependencies = [ - "aiobotocore>=2.15", "authlib", - "botocore>=1.35", "cachetools", - "dirac", "diracx-core", + "diracx-logic", "diracx-db", "python-dotenv", # TODO: We might not need this "python-multipart", @@ -37,9 +35,6 @@ dynamic = ["version"] [project.optional-dependencies] testing = ["diracx-testing", "moto[server]", "pytest-httpx", "freezegun",] types = [ - "boto3-stubs", - "types-aiobotocore[essential]", - "types-aiobotocore-s3", "types-cachetools", "types-python-dateutil", "types-PyYAML", diff --git a/diracx-routers/src/diracx/routers/access_policies.py b/diracx-routers/src/diracx/routers/access_policies.py index a2bf007b..e19e64c2 100644 --- a/diracx-routers/src/diracx/routers/access_policies.py +++ b/diracx-routers/src/diracx/routers/access_policies.py @@ -27,6 +27,10 @@ from fastapi import Depends from diracx.core.extensions import select_from_extension +from diracx.core.models import ( + AccessTokenPayload, + RefreshTokenPayload, +) from diracx.routers.dependencies import DevelopmentSettings from diracx.routers.utils.users import AuthorizedUserInfo, 
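To make the share arithmetic of `calculate_priority` (introduced in `diracx-logic/src/diracx/logic/task_queues/priority.py` above) concrete, here is a small worked example with made-up numbers: queues at or below priority 0.1 are treated as background and pinned to `TQ_MIN_SHARE`, the others split `share` proportionally and are then merged when their remaining fields are identical.

import asyncio

from diracx.logic.task_queues.priority import calculate_priority

tq_dict = {1: 1.0, 2: 1.0, 3: 0.05}  # TQ 3 is "background" (priority <= 0.1)
all_tqs_data = {
    1: {"Owner": "alice", "CPUTime": 3600, "Jobs": 10},
    2: {"Owner": "alice", "CPUTime": 3600, "Jobs": 2},  # same fields as TQ 1 once "Jobs" is ignored
    3: {"Owner": "bob", "CPUTime": 86400, "Jobs": 1},
}

result = asyncio.run(
    calculate_priority(tq_dict, all_tqs_data, share=10.0, allow_bg_tqs=True)
)
# TQs 1 and 2 each get 10 * (1.0 / 2.0) = 5.0, then are grouped and summed to 10.0;
# TQ 3 falls back to TQ_MIN_SHARE, so result == {10.0: [1, 2], 0.001: [3]}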
verify_dirac_access_token @@ -88,7 +92,9 @@ async def policy(policy_name: str, user_info: AuthorizedUserInfo, /): return @staticmethod - def enrich_tokens(access_payload: dict, refresh_payload: dict) -> tuple[dict, dict]: + def enrich_tokens( + access_payload: AccessTokenPayload, refresh_payload: RefreshTokenPayload + ) -> tuple[dict, dict]: """This method is called when issuing a token, and can add whatever content it wants inside the access or refresh payload. diff --git a/diracx-routers/src/diracx/routers/auth/__init__.py b/diracx-routers/src/diracx/routers/auth/__init__.py index 7d2a93a0..ed71900f 100644 --- a/diracx-routers/src/diracx/routers/auth/__init__.py +++ b/diracx-routers/src/diracx/routers/auth/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations from ..fastapi_classes import DiracxRouter -from ..utils.users import AuthorizedUserInfo, verify_dirac_access_token +from ..utils.users import verify_dirac_access_token from .authorize_code_flow import router as authorize_code_flow_router from .device_flow import router as device_flow_router from .management import router as management_router @@ -14,4 +14,4 @@ router.include_router(authorize_code_flow_router) router.include_router(token_router) -__all__ = ["AuthorizedUserInfo", "has_properties", "verify_dirac_access_token"] +__all__ = ["has_properties", "verify_dirac_access_token"] diff --git a/diracx-routers/src/diracx/routers/auth/authorize_code_flow.py b/diracx-routers/src/diracx/routers/auth/authorize_code_flow.py index b52172ed..2d9ede61 100644 --- a/diracx-routers/src/diracx/routers/auth/authorize_code_flow.py +++ b/diracx-routers/src/diracx/routers/auth/authorize_code_flow.py @@ -43,26 +43,27 @@ status, ) +from diracx.core.exceptions import AuthorizationError, IAMClientError, IAMServerError +from diracx.logic.auth.authorize_code_flow import ( + complete_authorization_flow as complete_authorization_flow_bl, +) +from diracx.logic.auth.authorize_code_flow import ( + initiate_authorization_flow as initiate_authorization_flow_bl, +) + from ..dependencies import ( AuthDB, + AuthSettings, AvailableSecurityProperties, Config, ) from ..fastapi_classes import DiracxRouter -from ..utils.users import AuthSettings -from .utils import ( - GrantType, - decrypt_state, - get_token_from_iam, - initiate_authorization_flow_with_iam, - parse_and_validate_scope, -) router = DiracxRouter(require_auth=False) @router.get("/authorize") -async def authorization_flow( +async def initiate_authorization_flow( request: Request, response_type: Literal["code"], code_challenge: str, @@ -75,7 +76,7 @@ async def authorization_flow( config: Config, available_properties: AvailableSecurityProperties, settings: AuthSettings, -): +) -> responses.RedirectResponse: """Initiate the authorization flow. It will redirect to the actual OpenID server (IAM, CheckIn) to perform a authorization code flow. @@ -95,90 +96,64 @@ async def authorization_flow( to be able to map the authorization flow with the corresponding user authorize flow. 
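For reference, a hedged sketch of the request a client would send to this route (not part of the patch; the `/api/auth` prefix, client ID, and scope values are assumptions, while the parameter names come from the signature above):

import httpx

response = httpx.get(
    "https://diracx.invalid/api/auth/authorize",  # assumed deployment URL and mount point
    params={
        "response_type": "code",
        "code_challenge": "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM",  # S256 of the client's verifier
        "code_challenge_method": "S256",
        "client_id": "myDIRACClientID",
        "redirect_uri": "http://localhost:8000/docs/oauth2-redirect",
        "scope": "vo:myvo group:my_group",
        "state": "opaque-client-state",
    },
    follow_redirects=False,
)
# On success the service answers with a redirect towards the IAM authorization endpoint
assert response.is_redirect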
""" - if settings.dirac_client_id != client_id: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail="Unrecognised client ID" - ) - if redirect_uri not in settings.allowed_redirects: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail="Unrecognised redirect_uri" - ) - try: - parsed_scope = parse_and_validate_scope(scope, config, available_properties) + redirect_uri = await initiate_authorization_flow_bl( + request_url=f"{request.url.replace(query='')}", + code_challenge=code_challenge, + code_challenge_method=code_challenge_method, + client_id=client_id, + redirect_uri=redirect_uri, + scope=scope, + state=state, + auth_db=auth_db, + config=config, + settings=settings, + available_properties=available_properties, + ) except ValueError as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=e.args[0], - ) from e - except PermissionError as e: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=e.args[0], + detail=str(e), ) from e - # Store the authorization flow details - uuid = await auth_db.insert_authorization_flow( - client_id, - scope, - code_challenge, - code_challenge_method, - redirect_uri, - ) - - # Initiate the authorization flow with the IAM - state_for_iam = { - "external_state": state, - "uuid": uuid, - "grant_type": GrantType.authorization_code.value, - } - - authorization_flow_url = await initiate_authorization_flow_with_iam( - config, - parsed_scope["vo"], - f"{request.url.replace(query='')}/complete", - state_for_iam, - settings.state_key.fernet, - ) - - return responses.RedirectResponse(authorization_flow_url) + return responses.RedirectResponse(redirect_uri) @router.get("/authorize/complete") -async def authorization_flow_complete( +async def complete_authorization_flow( code: str, state: str, request: Request, auth_db: AuthDB, config: Config, settings: AuthSettings, -): +) -> responses.RedirectResponse: """Complete the authorization flow. The user is redirected back to the DIRAC auth service after completing the IAM's authorization flow. We retrieve the original flow details from the decrypted state and store the ID token requested from the IAM. The user is then redirected to the client's redirect URI. 
""" - # Decrypt the state to access user details - decrypted_state = decrypt_state(state, settings.state_key.fernet) - assert decrypted_state["grant_type"] == GrantType.authorization_code - - # Get the ID token from the IAM - id_token = await get_token_from_iam( - config, - decrypted_state["vo"], - code, - decrypted_state, - str(request.url.replace(query="")), - ) - - # Store the ID token and redirect the user to the client's redirect URI - code, redirect_uri = await auth_db.authorization_flow_insert_id_token( - decrypted_state["uuid"], - id_token, - settings.authorization_flow_expiration_seconds, - ) - - return responses.RedirectResponse( - f"{redirect_uri}?code={code}&state={decrypted_state['external_state']}" - ) + try: + redirect_uri = await complete_authorization_flow_bl( + code=code, + state=state, + request_url=str(request.url.replace(query="")), + auth_db=auth_db, + config=config, + settings=settings, + ) + except AuthorizationError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid state" + ) from e + except IAMServerError as e: + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail="Failed to contact IAM server", + ) from e + except IAMClientError as e: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid code" + ) from e + return responses.RedirectResponse(redirect_uri) diff --git a/diracx-routers/src/diracx/routers/auth/device_flow.py b/diracx-routers/src/diracx/routers/auth/device_flow.py index 886f1cf2..879dfa59 100644 --- a/diracx-routers/src/diracx/routers/auth/device_flow.py +++ b/diracx-routers/src/diracx/routers/auth/device_flow.py @@ -62,36 +62,28 @@ status, ) from fastapi.responses import RedirectResponse -from typing_extensions import TypedDict + +from diracx.core.exceptions import IAMClientError, IAMServerError +from diracx.core.models import InitiateDeviceFlowResponse +from diracx.logic.auth.device_flow import do_device_flow as do_device_flow_bl +from diracx.logic.auth.device_flow import ( + finish_device_flow as finish_device_flow_bl, +) +from diracx.logic.auth.device_flow import ( + initiate_device_flow as initiate_device_flow_bl, +) from ..dependencies import ( AuthDB, + AuthSettings, AvailableSecurityProperties, Config, ) from ..fastapi_classes import DiracxRouter -from ..utils.users import AuthSettings -from .utils import ( - GrantType, - decrypt_state, - get_token_from_iam, - initiate_authorization_flow_with_iam, - parse_and_validate_scope, -) router = DiracxRouter(require_auth=False) -class InitiateDeviceFlowResponse(TypedDict): - """Response for the device flow initiation.""" - - user_code: str - device_code: str - verification_uri_complete: str - verification_uri: str - expires_in: int - - @router.post("/device") async def initiate_device_flow( client_id: str, @@ -118,35 +110,23 @@ async def initiate_device_flow( Offers the user to go with the browser to `auth//device?user_code=XYZ` """ - if settings.dirac_client_id != client_id: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail="Unrecognised client ID" - ) - try: - parse_and_validate_scope(scope, config, available_properties) + device_flow_response = await initiate_device_flow_bl( + client_id=client_id, + scope=scope, + verification_uri=str(request.url.replace(query={})), + auth_db=auth_db, + config=config, + available_properties=available_properties, + settings=settings, + ) except ValueError as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=e.args[0], ) from e - except 
PermissionError as e: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=e.args[0], - ) from e - - user_code, device_code = await auth_db.insert_device_flow(client_id, scope) - - verification_uri = str(request.url.replace(query={})) - return { - "user_code": user_code, - "device_code": device_code, - "verification_uri_complete": f"{verification_uri}?user_code={user_code}", - "verification_uri": str(request.url.replace(query={})), - "expires_in": settings.device_flow_expiration_seconds, - } + return device_flow_response @router.get("/device") @@ -167,25 +147,13 @@ async def do_device_flow( device flow. (note: it can't be put as parameter or in the URL) """ - # Here we make sure the user_code actually exists - scope = await auth_db.device_flow_validate_user_code( - user_code, settings.device_flow_expiration_seconds - ) - parsed_scope = parse_and_validate_scope(scope, config, available_properties) - - redirect_uri = f"{request.url.replace(query='')}/complete" - - state_for_iam = { - "grant_type": GrantType.device_code.value, - "user_code": user_code, - } - - authorization_flow_url = await initiate_authorization_flow_with_iam( - config, - parsed_scope["vo"], - redirect_uri, - state_for_iam, - settings.state_key.fernet, + authorization_flow_url = await do_device_flow_bl( + request_url=str(request.url.replace(query="")), + auth_db=auth_db, + user_code=user_code, + config=config, + available_properties=available_properties, + settings=settings, ) return RedirectResponse(authorization_flow_url) @@ -198,28 +166,36 @@ async def finish_device_flow( auth_db: AuthDB, config: Config, settings: AuthSettings, -): +) -> RedirectResponse: """This the url callbacked by IAM/Checkin after the authorization flow was granted. It gets us the code we need for the authorization flow, and we can map it to the corresponding device flow using the user_code in the cookie/session. 
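Seen from a client, the device flow these routes implement looks roughly like the sketch below (illustrative only; the `/api/auth` prefix and client ID are assumptions, the response fields come from `InitiateDeviceFlowResponse`, and the polling step uses the standard OAuth2 device-code grant):

import time

import httpx

BASE = "https://diracx.invalid/api/auth"  # assumed mount point

with httpx.Client() as client:
    # 1. Initiate the flow: returns user_code, device_code and the verification URIs
    flow = client.post(
        f"{BASE}/device",
        params={"client_id": "myDIRACClientID", "scope": "vo:myvo"},
    ).json()
    print("Please visit:", flow["verification_uri_complete"])

    # 2. Poll the token endpoint; "authorization_pending" (HTTP 400) means keep waiting
    while True:
        resp = client.post(
            f"{BASE}/token",
            data={
                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
                "device_code": flow["device_code"],
                "client_id": "myDIRACClientID",
            },
        )
        if resp.status_code == 200:
            tokens = resp.json()  # access_token, refresh_token, expires_in, ...
            break
        time.sleep(5)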
""" - decrypted_state = decrypt_state(state, settings.state_key.fernet) - assert decrypted_state["grant_type"] == GrantType.device_code - - id_token = await get_token_from_iam( - config, - decrypted_state["vo"], - code, - decrypted_state, - str(request.url.replace(query="")), - ) - await auth_db.device_flow_insert_id_token( - decrypted_state["user_code"], id_token, settings.device_flow_expiration_seconds - ) + request_url = str(request.url.replace(query={})) + + try: + await finish_device_flow_bl( + request_url, + code, + state, + auth_db, + config, + settings, + ) + except IAMServerError as e: + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail=e.args[0], + ) from e + except IAMClientError as e: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=e.args[0], + ) from e - return responses.RedirectResponse(f"{request.url.replace(query='')}/finished") + return responses.RedirectResponse(f"{request_url}/finished") @router.get("/device/complete/finished") diff --git a/diracx-routers/src/diracx/routers/auth/management.py b/diracx-routers/src/diracx/routers/auth/management.py index 7a7df038..7fafc20e 100644 --- a/diracx-routers/src/diracx/routers/auth/management.py +++ b/diracx-routers/src/diracx/routers/auth/management.py @@ -7,6 +7,7 @@ from __future__ import annotations from typing import Annotated, Any +from uuid import UUID from fastapi import ( Depends, @@ -15,7 +16,14 @@ ) from typing_extensions import TypedDict +from diracx.core.exceptions import TokenNotFoundError from diracx.core.properties import PROXY_MANAGEMENT, SecurityProperty +from diracx.logic.auth.management import ( + get_refresh_tokens as get_refresh_tokens_bl, +) +from diracx.logic.auth.management import ( + revoke_refresh_token as revoke_refresh_token_bl, +) from ..dependencies import ( AuthDB, @@ -49,7 +57,7 @@ async def get_refresh_tokens( if PROXY_MANAGEMENT in user_info.properties: subject = None - return await auth_db.get_user_refresh_tokens(subject) + return await get_refresh_tokens_bl(auth_db, subject) @router.delete("/refresh-tokens/{jti}") @@ -61,19 +69,27 @@ async def revoke_refresh_token( """Revoke a refresh token. If the user has the `proxy_management` property, then the subject is not used to filter the refresh tokens. 
""" - res = await auth_db.get_refresh_token(jti) - if not res: + subject: str | None = user_info.sub + if PROXY_MANAGEMENT in user_info.properties: + subject = None + + try: + await revoke_refresh_token_bl(auth_db, subject, UUID(jti, version=4)) + except ValueError as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="JTI provided does not exist", - ) - - if PROXY_MANAGEMENT not in user_info.properties and user_info.sub != res["Sub"]: + detail=str(e), + ) from e + except PermissionError as e: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="Cannot revoke a refresh token owned by someone else", - ) - await auth_db.revoke_refresh_token(jti) + detail=str(e), + ) from e + except TokenNotFoundError as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=str(e), + ) from e return f"Refresh token {jti} revoked" diff --git a/diracx-routers/src/diracx/routers/auth/token.py b/diracx-routers/src/diracx/routers/auth/token.py index 1e61b14e..e2cb47cd 100644 --- a/diracx-routers/src/diracx/routers/auth/token.py +++ b/diracx-routers/src/diracx/routers/auth/token.py @@ -1,41 +1,74 @@ -"""Token endpoint implementation.""" +"""Token endpoint.""" from __future__ import annotations -import base64 -import hashlib import os -import re -from datetime import timedelta from typing import Annotated, Literal -from uuid import uuid4 -from authlib.jose import JsonWebToken +from authlib.jose import JoseError from fastapi import Depends, Form, Header, HTTPException, status from diracx.core.exceptions import ( DiracHttpResponseError, ExpiredFlowError, + InvalidCredentialsError, PendingAuthorizationError, ) -from diracx.core.models import TokenResponse -from diracx.db.sql.auth.schema import FlowStatus, RefreshTokenStatus +from diracx.core.models import ( + AccessTokenPayload, + GrantType, + RefreshTokenPayload, + TokenResponse, +) +from diracx.logic.auth.token import create_token +from diracx.logic.auth.token import get_oidc_token as get_oidc_token_bl +from diracx.logic.auth.token import ( + perform_legacy_exchange as perform_legacy_exchange_bl, +) from diracx.routers.access_policies import BaseAccessPolicy -from diracx.routers.auth.utils import GrantType -from ..dependencies import AuthDB, AvailableSecurityProperties, Config +from ..dependencies import AuthDB, AuthSettings, AvailableSecurityProperties, Config from ..fastapi_classes import DiracxRouter -from ..utils.users import AuthSettings, get_allowed_user_properties -from .utils import ( - parse_and_validate_scope, - verify_dirac_refresh_token, -) router = DiracxRouter(require_auth=False) +async def mint_token( + access_payload: AccessTokenPayload, + refresh_payload: RefreshTokenPayload, + all_access_policies: dict[str, BaseAccessPolicy], + settings: AuthSettings, +) -> TokenResponse: + """Enrich the token with policy specific content and mint it.""" + # Enrich the token with policy specific content + dirac_access_policies = {} + dirac_refresh_policies = {} + for policy_name, policy in all_access_policies.items(): + + access_extra, refresh_extra = policy.enrich_tokens( + access_payload, refresh_payload + ) + if access_extra: + dirac_access_policies[policy_name] = access_extra + if refresh_extra: + dirac_refresh_policies[policy_name] = refresh_extra + + access_payload["dirac_policies"] = dirac_access_policies + refresh_payload["dirac_policies"] = dirac_refresh_policies + + # Generate the token: encode the payloads + access_token = create_token(access_payload, settings) + refresh_token = 
create_token(refresh_payload, settings) + + return TokenResponse( + access_token=access_token, + expires_in=settings.access_token_expire_minutes * 60, + refresh_token=refresh_token, + ) + + @router.post("/token") -async def token( +async def get_oidc_token( # Autorest does not support the GrantType annotation # We need to specify each option with Literal[] grant_type: Annotated[ @@ -76,182 +109,54 @@ async def token( """Token endpoint to retrieve the token at the end of a flow. This is the endpoint being pulled by dirac-login when doing the device flow. """ - legacy_exchange = False - - if grant_type == GrantType.device_code: - oidc_token_info, scope = await get_oidc_token_info_from_device_flow( - device_code, client_id, auth_db, settings - ) - - elif grant_type == GrantType.authorization_code: - oidc_token_info, scope = await get_oidc_token_info_from_authorization_flow( - code, client_id, redirect_uri, code_verifier, auth_db, settings - ) - - elif grant_type == GrantType.refresh_token: - ( - oidc_token_info, - scope, - legacy_exchange, - ) = await get_oidc_token_info_from_refresh_flow( - refresh_token, auth_db, settings - ) - else: - raise NotImplementedError(f"Grant type not implemented {grant_type}") - - # Get a TokenResponse to return to the user - return await exchange_token( - auth_db, - scope, - oidc_token_info, - config, - settings, - available_properties, - all_access_policies=all_access_policies, - legacy_exchange=legacy_exchange, - ) - - -async def get_oidc_token_info_from_device_flow( - device_code: str | None, client_id: str, auth_db: AuthDB, settings: AuthSettings -): - """Get OIDC token information from the device flow DB and check few parameters before returning it.""" - assert device_code is not None try: - info = await auth_db.get_device_flow( - device_code, settings.device_flow_expiration_seconds + access_payload, refresh_payload = await get_oidc_token_bl( + grant_type, + client_id, + auth_db, + config, + settings, + available_properties, + device_code=device_code, + code=code, + redirect_uri=redirect_uri, + code_verifier=code_verifier, + refresh_token=refresh_token, ) except PendingAuthorizationError as e: raise DiracHttpResponseError( - status.HTTP_400_BAD_REQUEST, {"error": "authorization_pending"} + status_code=status.HTTP_400_BAD_REQUEST, + data={"error": "authorization_pending"}, + ) from e + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), ) from e except ExpiredFlowError as e: raise DiracHttpResponseError( - status.HTTP_401_UNAUTHORIZED, {"error": "expired_token"} + status_code=status.HTTP_401_UNAUTHORIZED, + data={"error": "expired_token"}, ) from e - # raise DiracHttpResponseError(status.HTTP_400_BAD_REQUEST, {"error": "slow_down"}) - # raise DiracHttpResponseError(status.HTTP_400_BAD_REQUEST, {"error": "expired_token"}) - - if info["ClientID"] != client_id: + except InvalidCredentialsError as e: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Bad client_id", - ) - oidc_token_info = info["IDToken"] - scope = info["Scope"] - - # TODO: use HTTPException while still respecting the standard format - # required by the RFC - if info["Status"] != FlowStatus.READY: - # That should never ever happen - raise NotImplementedError(f"Unexpected flow status {info['status']!r}") - return (oidc_token_info, scope) - - -async def get_oidc_token_info_from_authorization_flow( - code: str | None, - client_id: str | None, - redirect_uri: str | None, - code_verifier: str | None, - auth_db: AuthDB, - 
settings: AuthSettings, -): - """Get OIDC token information from the authorization flow DB and check few parameters before returning it.""" - assert code is not None - info = await auth_db.get_authorization_flow( - code, settings.authorization_flow_expiration_seconds - ) - if redirect_uri != info["RedirectURI"]: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid redirect_uri", - ) - if client_id != info["ClientID"]: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Bad client_id", - ) - - # Check the code_verifier - try: - assert code_verifier is not None - code_challenge = ( - base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()) - .decode() - .strip("=") - ) - except Exception as e: + status_code=status.HTTP_401_UNAUTHORIZED, + detail=str(e), + headers={"WWW-Authenticate": "Bearer"}, + ) from e + except JoseError as e: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Malformed code_verifier", + status_code=status.HTTP_401_UNAUTHORIZED, + detail=f"Invalid JWT: {e}", ) from e - - if code_challenge != info["CodeChallenge"]: + except PermissionError as e: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid code_challenge", - ) - - oidc_token_info = info["IDToken"] - scope = info["Scope"] - - # TODO: use HTTPException while still respecting the standard format - # required by the RFC - if info["Status"] != FlowStatus.READY: - # That should never ever happen - raise NotImplementedError(f"Unexpected flow status {info['status']!r}") - - return (oidc_token_info, scope) - - -async def get_oidc_token_info_from_refresh_flow( - refresh_token: str | None, auth_db: AuthDB, settings: AuthSettings -): - """Get OIDC token information from the refresh token DB and check few parameters before returning it.""" - assert refresh_token is not None - - # Decode the refresh token to get the JWT ID - jti, _, legacy_exchange = await verify_dirac_refresh_token(refresh_token, settings) - - # Get some useful user information from the refresh token entry in the DB - refresh_token_attributes = await auth_db.get_refresh_token(jti) - - sub = refresh_token_attributes["Sub"] - - # Check if the refresh token was obtained from the legacy_exchange endpoint - # If it is the case, we bypass the refresh token rotation mechanism - if not legacy_exchange: - # Refresh token rotation: https://datatracker.ietf.org/doc/html/rfc6749#section-10.4 - # Check that the refresh token has not been already revoked - # This might indicate that a potential attacker try to impersonate someone - # In such case, all the refresh tokens bound to a given user (subject) should be revoked - # Forcing the user to reauthenticate interactively through an authorization/device flow (recommended practice) - if refresh_token_attributes["Status"] == RefreshTokenStatus.REVOKED: - # Revoke all the user tokens from the subject - await auth_db.revoke_user_refresh_tokens(sub) - - # Commit here, otherwise the revokation operation will not be taken into account - # as we return an error to the user - await auth_db.conn.commit() - - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Revoked refresh token reused: potential attack detected. 
You must authenticate again", - ) - - # Part of the refresh token rotation mechanism: - # Revoke the refresh token provided, a new one needs to be generated - await auth_db.revoke_refresh_token(jti) - - # Build an ID token and get scope from the refresh token attributes received - oidc_token_info = { - # The sub attribute coming from the DB contains the VO name - # We need to remove it as if it were coming from an ID token from an external IdP - "sub": sub.split(":", 1)[1], - "preferred_username": refresh_token_attributes["PreferredUsername"], - } - scope = refresh_token_attributes["Scope"] - return (oidc_token_info, scope, legacy_exchange) + status_code=status.HTTP_403_FORBIDDEN, + detail=str(e), + ) from e + return await mint_token( + access_payload, refresh_payload, all_access_policies, settings + ) BASE_64_URL_SAFE_PATTERN = ( @@ -261,7 +166,7 @@ async def get_oidc_token_info_from_refresh_flow( @router.get("/legacy-exchange", include_in_schema=False) -async def legacy_exchange( +async def perform_legacy_exchange( preferred_username: str, scope: str, authorization: Annotated[str, Header()], @@ -273,7 +178,7 @@ async def legacy_exchange( dict[str, BaseAccessPolicy], Depends(BaseAccessPolicy.all_used_access_policies) ], expires_minutes: int | None = None, -): +) -> TokenResponse: """Endpoint used by legacy DIRAC to mint tokens for proxy -> token exchange. This route is disabled if DIRACX_LEGACY_EXCHANGE_HASHED_API_KEY is not set @@ -303,171 +208,34 @@ async def legacy_exchange( detail="Legacy exchange is not enabled", ) - if match := re.fullmatch(LEGACY_EXCHANGE_PATTERN, authorization): - raw_token = base64.urlsafe_b64decode(match.group(1)) - else: + try: + access_payload, refresh_payload = await perform_legacy_exchange_bl( + expected_api_key=expected_api_key, + preferred_username=preferred_username, + scope=scope, + authorization=authorization, + auth_db=auth_db, + available_properties=available_properties, + settings=settings, + config=config, + expires_minutes=expires_minutes, + ) + except ValueError as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid authorization header", - ) - - if hashlib.sha256(raw_token).hexdigest() != expected_api_key: + detail=str(e), + ) from e + except InvalidCredentialsError as e: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid credentials", + detail=str(e), headers={"WWW-Authenticate": "Bearer"}, - ) - - try: - parsed_scope = parse_and_validate_scope(scope, config, available_properties) - vo_users = config.Registry[parsed_scope["vo"]] - sub = vo_users.sub_from_preferred_username(preferred_username) - except (KeyError, ValueError) as e: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid scope or preferred_username", ) from e except PermissionError as e: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail=e.args[0], + detail=str(e), ) from e - - return await exchange_token( - auth_db, - scope, - {"sub": sub, "preferred_username": preferred_username}, - config, - settings, - available_properties, - all_access_policies=all_access_policies, - refresh_token_expire_minutes=expires_minutes, - legacy_exchange=True, - ) - - -async def exchange_token( - auth_db: AuthDB, - scope: str, - oidc_token_info: dict, - config: Config, - settings: AuthSettings, - available_properties: AvailableSecurityProperties, - all_access_policies: Annotated[ - dict[str, BaseAccessPolicy], Depends(BaseAccessPolicy.all_used_access_policies) - ], - *, - 
refresh_token_expire_minutes: int | None = None, - legacy_exchange: bool = False, -) -> TokenResponse: - """Method called to exchange the OIDC token for a DIRAC generated access token.""" - # Extract dirac attributes from the OIDC scope - try: - parsed_scope = parse_and_validate_scope(scope, config, available_properties) - except ValueError as e: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=e.args[0], - ) from e - vo = parsed_scope["vo"] - dirac_group = parsed_scope["group"] - properties = parsed_scope["properties"] - - # Extract attributes from the OIDC token details - sub = oidc_token_info["sub"] - if user_info := config.Registry[vo].Users.get(sub): - preferred_username = user_info.PreferedUsername - else: - preferred_username = oidc_token_info.get("preferred_username", sub) - raise NotImplementedError( - "Dynamic registration of users is not yet implemented" - ) - - # Extract attributes from the settings and configuration - issuer = settings.token_issuer - - # Check that the subject is part of the dirac users - if sub not in config.Registry[vo].Groups[dirac_group].Users: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=f"User is not a member of the requested group ({preferred_username}, {dirac_group})", - ) - - # Check that the user properties are valid - allowed_user_properties = get_allowed_user_properties(config, sub, vo) - if not properties.issubset(allowed_user_properties): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=f"{' '.join(properties - allowed_user_properties)} are not valid properties " - f"for user {preferred_username}, available values: {' '.join(allowed_user_properties)}", - ) - - # Merge the VO with the subject to get a unique DIRAC sub - sub = f"{vo}:{sub}" - - # Insert the refresh token with user details into the RefreshTokens table - # User details are needed to regenerate access tokens later - jti, creation_time = await auth_db.insert_refresh_token( - subject=sub, - preferred_username=preferred_username, - scope=scope, - ) - - # Generate refresh token payload - if refresh_token_expire_minutes is None: - refresh_token_expire_minutes = settings.refresh_token_expire_minutes - refresh_payload = { - "jti": jti, - "exp": creation_time + timedelta(minutes=refresh_token_expire_minutes), - # legacy_exchange is used to indicate that the original refresh token - # was obtained from the legacy_exchange endpoint - "legacy_exchange": legacy_exchange, - } - - # Generate access token payload - # For now, the access token is only used to access DIRAC services, - # therefore, the audience is not set and checked - access_payload = { - "sub": sub, - "vo": vo, - "iss": issuer, - "dirac_properties": list(properties), - "jti": str(uuid4()), - "preferred_username": preferred_username, - "dirac_group": dirac_group, - "exp": creation_time + timedelta(minutes=settings.access_token_expire_minutes), - } - - # Enrich the token payload with policy specific content - dirac_access_policies = {} - dirac_refresh_policies = {} - for policy_name, policy in all_access_policies.items(): - - access_extra, refresh_extra = policy.enrich_tokens( - access_payload, refresh_payload - ) - if access_extra: - dirac_access_policies[policy_name] = access_extra - if refresh_extra: - dirac_refresh_policies[policy_name] = refresh_extra - - access_payload["dirac_policies"] = dirac_access_policies - refresh_payload["dirac_policies"] = dirac_refresh_policies - - # Generate the token: encode the payloads - access_token = 
create_token(access_payload, settings) - refresh_token = create_token(refresh_payload, settings) - - return TokenResponse( - access_token=access_token, - expires_in=settings.access_token_expire_minutes * 60, - refresh_token=refresh_token, - ) - - -def create_token(payload: dict, settings: AuthSettings) -> str: - jwt = JsonWebToken(settings.token_algorithm) - encoded_jwt = jwt.encode( - {"alg": settings.token_algorithm}, payload, settings.token_key.jwk + return await mint_token( + access_payload, refresh_payload, all_access_policies, settings ) - return encoded_jwt.decode("ascii") diff --git a/diracx-routers/src/diracx/routers/auth/utils.py b/diracx-routers/src/diracx/routers/auth/utils.py index c0a37a07..57473b2c 100644 --- a/diracx-routers/src/diracx/routers/auth/utils.py +++ b/diracx-routers/src/diracx/routers/auth/utils.py @@ -1,45 +1,14 @@ from __future__ import annotations -import base64 -import hashlib -import json -import secrets -from enum import StrEnum -from typing import Annotated, TypedDict +from typing import Annotated -import httpx -from authlib.integrations.starlette_client import OAuthError -from authlib.jose import JoseError, JsonWebKey, JsonWebToken -from authlib.oidc.core import IDToken -from cachetools import TTLCache -from cryptography.fernet import Fernet from fastapi import Depends, HTTPException, status from diracx.core.properties import ( SecurityProperty, UnevaluatedProperty, ) -from diracx.routers.utils.users import ( - AuthorizedUserInfo, - AuthSettings, - verify_dirac_access_token, -) - -from ..dependencies import Config - - -class GrantType(StrEnum): - """Grant types for OAuth2.""" - - authorization_code = "authorization_code" - device_code = "urn:ietf:params:oauth:grant-type:device_code" - refresh_token = "refresh_token" # noqa: S105 # False positive of Bandit about hard coded password - - -class ScopeInfoDict(TypedDict): - group: str - properties: set[str] - vo: str +from diracx.routers.utils.users import AuthorizedUserInfo, verify_dirac_access_token def has_properties(expression: UnevaluatedProperty | SecurityProperty): @@ -57,266 +26,3 @@ async def require_property( raise HTTPException(status.HTTP_403_FORBIDDEN) return Depends(require_property) - - -_server_metadata_cache: TTLCache = TTLCache(maxsize=1024, ttl=3600) - - -async def get_server_metadata(url: str): - """Get the server metadata from the IAM.""" - server_metadata = _server_metadata_cache.get(url) - if server_metadata is None: - async with httpx.AsyncClient() as c: - res = await c.get(url) - if res.status_code != 200: - # TODO: Better error handling - raise NotImplementedError(res) - server_metadata = res.json() - _server_metadata_cache[url] = server_metadata - return server_metadata - - -async def fetch_jwk_set(url: str): - """Fetch the JWK set from the IAM.""" - server_metadata = await get_server_metadata(url) - - jwks_uri = server_metadata.get("jwks_uri") - if not jwks_uri: - raise RuntimeError('Missing "jwks_uri" in metadata') - - async with httpx.AsyncClient() as c: - res = await c.get(jwks_uri) - if res.status_code != 200: - # TODO: Better error handling - raise NotImplementedError(res) - jwk_set = res.json() - - # self.server_metadata['jwks'] = jwk_set - return JsonWebKey.import_key_set(jwk_set) - - -async def parse_id_token(config, vo, raw_id_token: str): - """Parse and validate the ID token from IAM.""" - server_metadata = await get_server_metadata( - config.Registry[vo].IdP.server_metadata_url - ) - alg_values = server_metadata.get("id_token_signing_alg_values_supported", ["RS256"]) - 
jwk_set = await fetch_jwk_set(config.Registry[vo].IdP.server_metadata_url) - - token = JsonWebToken(alg_values).decode( - raw_id_token, - key=jwk_set, - claims_cls=IDToken, - claims_options={ - "iss": {"values": [server_metadata["issuer"]]}, - # The audience is a required parameter and is the client ID of the application - # https://openid.net/specs/openid-connect-core-1_0.html#IDToken - "aud": {"values": [config.Registry[vo].IdP.ClientID]}, - }, - ) - token.validate() - return token - - -def encrypt_state(state_dict: dict[str, str], cipher_suite: Fernet) -> str: - """Encrypt the state dict and return it as a string.""" - return cipher_suite.encrypt( - base64.urlsafe_b64encode(json.dumps(state_dict).encode()) - ).decode() - - -def decrypt_state(state: str, cipher_suite: Fernet) -> dict[str, str]: - """Decrypt the state string and return it as a dict.""" - try: - return json.loads( - base64.urlsafe_b64decode(cipher_suite.decrypt(state.encode())).decode() - ) - except Exception as e: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid state" - ) from e - - -async def verify_dirac_refresh_token( - refresh_token: str, - settings: AuthSettings, -) -> tuple[str, float, bool]: - """Verify dirac user token and return a UserInfo class - Used for each API endpoint. - """ - try: - jwt = JsonWebToken(settings.token_algorithm) - token = jwt.decode( - refresh_token, - key=settings.token_key.jwk, - ) - token.validate() - # Handle problematic tokens such as: - # - tokens signed with an invalid JWK - # - expired tokens - except JoseError as e: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=f"Invalid JWT: {e.args[0]}", - headers={"WWW-Authenticate": "Bearer"}, - ) from e - - return (token["jti"], float(token["exp"]), token["legacy_exchange"]) - - -def parse_and_validate_scope( - scope: str, config: Config, available_properties: set[SecurityProperty] -) -> ScopeInfoDict: - """Check: - * At most one VO - * At most one group - * group belongs to VO - * properties are known - return dict with group and properties. 
- - :raises: - * ValueError in case the scope isn't valide - """ - scopes = set(scope.split(" ")) - - groups = [] - properties = [] - vos = [] - unrecognised = [] - for scope in scopes: - if scope.startswith("group:"): - groups.append(scope.split(":", 1)[1]) - elif scope.startswith("property:"): - properties.append(scope.split(":", 1)[1]) - elif scope.startswith("vo:"): - vos.append(scope.split(":", 1)[1]) - else: - unrecognised.append(scope) - if unrecognised: - raise ValueError(f"Unrecognised scopes: {unrecognised}") - - if not vos: - available_vo_scopes = [repr(f"vo:{vo}") for vo in config.Registry] - raise ValueError( - "No vo scope requested, available values: " - f"{' '.join(available_vo_scopes)}" - ) - elif len(vos) > 1: - raise ValueError(f"Only one vo is allowed but got {vos}") - else: - vo = vos[0] - if vo not in config.Registry: - raise ValueError(f"VO {vo} is not known to this installation") - - if not groups: - # TODO: Handle multiple groups correctly - group = config.Registry[vo].DefaultGroup - elif len(groups) > 1: - raise ValueError(f"Only one DIRAC group allowed but got {groups}") - else: - group = groups[0] - if group not in config.Registry[vo].Groups: - raise ValueError(f"{group} not in {vo} groups") - - allowed_properties = config.Registry[vo].Groups[group].Properties - properties.extend([str(p) for p in allowed_properties]) - - if not set(properties).issubset(available_properties): - raise ValueError( - f"{set(properties)-set(available_properties)} are not valid properties" - ) - - return { - "group": group, - "properties": set(sorted(properties)), - "vo": vo, - } - - -async def initiate_authorization_flow_with_iam( - config, vo: str, redirect_uri: str, state: dict[str, str], cipher_suite: Fernet -): - """Initiate the authorization flow with the IAM. Return the URL to redirect the user to. - - The state dict is encrypted and passed to the IAM. - It is then decrypted when the user is redirected back to the redirect_uri. - """ - # code_verifier: https://www.rfc-editor.org/rfc/rfc7636#section-4.1 - code_verifier = secrets.token_hex() - - # code_challenge: https://www.rfc-editor.org/rfc/rfc7636#section-4.2 - code_challenge = ( - base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()) - .decode() - .replace("=", "") - ) - - server_metadata = await get_server_metadata( - config.Registry[vo].IdP.server_metadata_url - ) - - # Take these two from CS/.well-known - authorization_endpoint = server_metadata["authorization_endpoint"] - - # Encrypt the state and pass it to the IAM - # Needed to retrieve the original flow details when the user is redirected back to the redirect_uri - encrypted_state = encrypt_state( - state | {"vo": vo, "code_verifier": code_verifier}, cipher_suite - ) - - url_params = [ - "response_type=code", - f"code_challenge={code_challenge}", - "code_challenge_method=S256", - f"client_id={config.Registry[vo].IdP.ClientID}", - f"redirect_uri={redirect_uri}", - "scope=openid%20profile", - f"state={encrypted_state}", - ] - authorization_flow_url = f"{authorization_endpoint}?{'&'.join(url_params)}" - return authorization_flow_url - - -async def get_token_from_iam( - config, vo: str, code: str, state: dict[str, str], redirect_uri: str -) -> dict[str, str]: - """Get the token from the IAM using the code and state. 
Return the ID token.""" - server_metadata = await get_server_metadata( - config.Registry[vo].IdP.server_metadata_url - ) - - # Take these two from CS/.well-known - token_endpoint = server_metadata["token_endpoint"] - - data = { - "grant_type": GrantType.authorization_code.value, - "client_id": config.Registry[vo].IdP.ClientID, - "code": code, - "code_verifier": state["code_verifier"], - "redirect_uri": redirect_uri, - } - - async with httpx.AsyncClient() as c: - res = await c.post( - token_endpoint, - data=data, - ) - if res.status_code >= 500: - raise HTTPException( - status.HTTP_502_BAD_GATEWAY, "Failed to contact token endpoint" - ) - elif res.status_code >= 400: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid code") - - raw_id_token = res.json()["id_token"] - # Extract the payload and verify it - try: - id_token = await parse_id_token( - config=config, - vo=vo, - raw_id_token=raw_id_token, - ) - except OAuthError: - raise - - return id_token diff --git a/diracx-routers/src/diracx/routers/auth/well_known.py b/diracx-routers/src/diracx/routers/auth/well_known.py index 94b6d11a..3002aa2d 100644 --- a/diracx-routers/src/diracx/routers/auth/well_known.py +++ b/diracx-routers/src/diracx/routers/auth/well_known.py @@ -1,95 +1,41 @@ from __future__ import annotations from fastapi import Request -from typing_extensions import TypedDict -from ..dependencies import Config, DevelopmentSettings +from diracx.core.models import Metadata, OpenIDConfiguration +from diracx.logic.auth.well_known import ( + get_installation_metadata as get_installation_metadata_bl, +) +from diracx.logic.auth.well_known import ( + get_openid_configuration as get_openid_configuration_bl, +) + +from ..dependencies import AuthSettings, Config from ..fastapi_classes import DiracxRouter -from ..utils.users import AuthSettings router = DiracxRouter(require_auth=False, path_root="") @router.get("/openid-configuration") -async def openid_configuration( +async def get_openid_configuration( request: Request, config: Config, settings: AuthSettings, -): +) -> OpenIDConfiguration: """OpenID Connect discovery endpoint.""" - # await check_permissions() - scopes_supported = [] - for vo in config.Registry: - scopes_supported.append(f"vo:{vo}") - scopes_supported += [f"group:{vo}" for vo in config.Registry[vo].Groups] - scopes_supported += [f"property:{p}" for p in settings.available_properties] - - return { - "issuer": settings.token_issuer, - "token_endpoint": str(request.url_for("token")), - "userinfo_endpoint:": str(request.url_for("userinfo")), - "authorization_endpoint": str(request.url_for("authorization_flow")), - "device_authorization_endpoint": str(request.url_for("initiate_device_flow")), - # "introspection_endpoint":"", - # "userinfo_endpoint":"", - "grant_types_supported": [ - "authorization_code", - "urn:ietf:params:oauth:grant-type:device_code", - ], - "scopes_supported": scopes_supported, - "response_types_supported": ["code"], - "token_endpoint_auth_signing_alg_values_supported": [settings.token_algorithm], - "token_endpoint_auth_methods_supported": ["none"], - "code_challenge_methods_supported": ["S256"], - } - - -class SupportInfo(TypedDict): - message: str - webpage: str | None - email: str | None - - -class GroupInfo(TypedDict): - properties: list[str] - - -class VOInfo(TypedDict): - groups: dict[str, GroupInfo] - support: SupportInfo - default_group: str - - -class Metadata(TypedDict): - virtual_organizations: dict[str, VOInfo] - development_settings: DevelopmentSettings + return await 
get_openid_configuration_bl( + str(request.url_for("get_oidc_token")), + str(request.url_for("userinfo")), + str(request.url_for("initiate_authorization_flow")), + str(request.url_for("initiate_device_flow")), + config, + settings, + ) @router.get("/dirac-metadata") -async def installation_metadata( +async def get_installation_metadata( config: Config, - # check_permissions: OpenAccessPolicyCallable, - dev_settings: DevelopmentSettings, ) -> Metadata: """Get metadata about the dirac installation.""" - # await check_permissions() - metadata: Metadata = { - "virtual_organizations": {}, - "development_settings": dev_settings, - } - for vo, vo_info in config.Registry.items(): - groups: dict[str, GroupInfo] = { - group: {"properties": sorted(group_info.Properties)} - for group, group_info in vo_info.Groups.items() - } - metadata["virtual_organizations"][vo] = { - "groups": groups, - "support": { - "message": vo_info.Support.Message, - "webpage": vo_info.Support.Webpage, - "email": vo_info.Support.Email, - }, - "default_group": vo_info.DefaultGroup, - } - - return metadata + return await get_installation_metadata_bl(config) diff --git a/diracx-routers/src/diracx/routers/dependencies.py b/diracx-routers/src/diracx/routers/dependencies.py index ab40190b..8eb2bd26 100644 --- a/diracx-routers/src/diracx/routers/dependencies.py +++ b/diracx-routers/src/diracx/routers/dependencies.py @@ -19,7 +19,9 @@ from diracx.core.config import Config as _Config from diracx.core.config import ConfigSource from diracx.core.properties import SecurityProperty +from diracx.core.settings import AuthSettings as _AuthSettings from diracx.core.settings import DevelopmentSettings as _DevelopmentSettings +from diracx.core.settings import SandboxStoreSettings as _SandboxStoreSettings from diracx.db.os import JobParametersDB as _JobParametersDB from diracx.db.sql import AuthDB as _AuthDB from diracx.db.sql import JobDB as _JobDB @@ -56,6 +58,10 @@ def add_settings_annotation(cls: T) -> T: set[SecurityProperty], Depends(SecurityProperty.available_properties) ] +AuthSettings = Annotated[_AuthSettings, Depends(_AuthSettings.create)] DevelopmentSettings = Annotated[ _DevelopmentSettings, Depends(_DevelopmentSettings.create) ] +SandboxStoreSettings = Annotated[ + _SandboxStoreSettings, Depends(_SandboxStoreSettings.create) +] diff --git a/diracx-routers/src/diracx/routers/jobs/access_policies.py b/diracx-routers/src/diracx/routers/jobs/access_policies.py index 00b012e3..3fe2330b 100644 --- a/diracx-routers/src/diracx/routers/jobs/access_policies.py +++ b/diracx-routers/src/diracx/routers/jobs/access_policies.py @@ -7,10 +7,9 @@ from fastapi import Depends, HTTPException, status from diracx.core.properties import JOB_ADMINISTRATOR, NORMAL_USER -from diracx.db.sql import JobDB, SandboxMetadataDB +from diracx.db.sql import JobDB from diracx.routers.access_policies import BaseAccessPolicy - -from ..utils.users import AuthorizedUserInfo +from diracx.routers.utils.users import AuthorizedUserInfo class ActionType(StrEnum): @@ -109,13 +108,10 @@ async def policy( /, *, action: ActionType | None = None, - sandbox_metadata_db: SandboxMetadataDB | None = None, pfns: list[str] | None = None, required_prefix: str | None = None, ): assert action, "action is a mandatory parameter" - assert sandbox_metadata_db, "sandbox_metadata_db is a mandatory parameter" - assert pfns, "pfns is a mandatory parameter" if action == ActionType.CREATE: @@ -130,14 +126,17 @@ async def policy( raise HTTPException(status.HTTP_403_FORBIDDEN) # Getting a sandbox or 
modifying it - if required_prefix is None: - raise NotImplementedError("required_prefix is None. his shouldn't happen") - for pfn in pfns: - if not pfn.startswith(required_prefix): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=f"Invalid PFN. PFN must start with {required_prefix}", + if pfns: + if required_prefix is None: + raise NotImplementedError( + "required_prefix is None. This shouldn't happen" ) + for pfn in pfns: + if not pfn.startswith(required_prefix): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Invalid PFN. PFN must start with {required_prefix}", + ) CheckSandboxPolicyCallable = Annotated[Callable, Depends(SandboxAccessPolicy.check)] diff --git a/diracx-routers/src/diracx/routers/jobs/query.py b/diracx-routers/src/diracx/routers/jobs/query.py index 591bb736..8e5be09f 100644 --- a/diracx-routers/src/diracx/routers/jobs/query.py +++ b/diracx-routers/src/diracx/routers/jobs/query.py @@ -1,17 +1,16 @@ from __future__ import annotations -import logging from http import HTTPStatus from typing import Annotated, Any from fastapi import Body, Depends, Response -from pydantic import BaseModel from diracx.core.models import ( - ScalarSearchOperator, - SearchSpec, - SortSpec, + JobSearchParams, + JobSummaryParams, ) +from diracx.logic.jobs.query import search as search_bl +from diracx.logic.jobs.query import summary as summary_bl from ..dependencies import ( Config, @@ -23,25 +22,9 @@ from ..utils.users import AuthorizedUserInfo, verify_dirac_access_token from .access_policies import ActionType, CheckWMSPolicyCallable -logger = logging.getLogger(__name__) - router = DiracxRouter() -class JobSummaryParams(BaseModel): - grouping: list[str] - search: list[SearchSpec] = [] - # TODO: Add more validation - - -class JobSearchParams(BaseModel): - parameters: list[str] | None = None - search: list[SearchSpec] = [] - sort: list[SortSpec] = [] - distinct: bool = False - # TODO: Add more validation - - MAX_PER_PAGE = 10000 @@ -160,48 +143,17 @@ async def search( """ await check_permissions(action=ActionType.QUERY, job_db=job_db) - # Apply a limit to per_page to prevent abuse of the API - if per_page > MAX_PER_PAGE: - per_page = MAX_PER_PAGE - - if body is None: - body = JobSearchParams() - - if query_logging_info := ("LoggingInfo" in (body.parameters or [])): - if body.parameters: - body.parameters.remove("LoggingInfo") - body.parameters = ["JobID"] + (body.parameters or []) - - # TODO: Apply all the job policy stuff properly using user_info - if not config.Operations["Defaults"].Services.JobMonitoring.GlobalJobsInfo: - body.search.append( - { - "parameter": "Owner", - "operator": ScalarSearchOperator.EQUAL, - # TODO-385: https://github.com/DIRACGrid/diracx/issues/385 - # The value shoud be user_info.sub, - # but since we historically rely on the preferred_username - # we will keep using the preferred_username for now. 
- "value": user_info.preferred_username, - } - ) - - total, jobs = await job_db.search( - body.parameters, - body.search, - body.sort, - distinct=body.distinct, + total, jobs = await search_bl( + config=config, + job_db=job_db, + job_parameters_db=job_parameters_db, + job_logging_db=job_logging_db, + preferred_username=user_info.preferred_username, page=page, per_page=per_page, + body=body, ) - if query_logging_info: - job_logging_info = await job_logging_db.get_records( - [job["JobID"] for job in jobs] - ) - for job in jobs: - job.update({"LoggingInfo": job_logging_info[job["JobID"]]}) - # Set the Content-Range header if needed # https://datatracker.ietf.org/doc/html/rfc7233#section-4 @@ -231,13 +183,10 @@ async def summary( ): """Show information suitable for plotting.""" await check_permissions(action=ActionType.QUERY, job_db=job_db) - # TODO: Apply all the job policy stuff properly using user_info - if not config.Operations["Defaults"].Services.JobMonitoring.GlobalJobsInfo: - body.search.append( - { - "parameter": "Owner", - "operator": ScalarSearchOperator.EQUAL, - "value": user_info.sub, - } - ) - return await job_db.summary(body.grouping, body.search) + + return await summary_bl( + config=config, + job_db=job_db, + preferred_username=user_info.preferred_username, + body=body, + ) diff --git a/diracx-routers/src/diracx/routers/jobs/sandboxes.py b/diracx-routers/src/diracx/routers/jobs/sandboxes.py index 01441032..09b2a6d9 100644 --- a/diracx-routers/src/diracx/routers/jobs/sandboxes.py +++ b/diracx-routers/src/diracx/routers/jobs/sandboxes.py @@ -1,30 +1,33 @@ from __future__ import annotations -import contextlib -from collections.abc import AsyncIterator from http import HTTPStatus -from typing import TYPE_CHECKING, Annotated, Literal +from typing import Annotated, Literal -from aiobotocore.session import get_session -from botocore.config import Config -from botocore.errorfactory import ClientError from fastapi import Body, Depends, HTTPException, Query -from pydantic import BaseModel, PrivateAttr -from pydantic_settings import SettingsConfigDict from pyparsing import Any from diracx.core.exceptions import SandboxAlreadyAssignedError, SandboxNotFoundError from diracx.core.models import ( + SandboxDownloadResponse, SandboxInfo, - SandboxType, + SandboxUploadResponse, ) -from diracx.core.s3 import ( - generate_presigned_upload, - s3_bucket_exists, - s3_object_exists, +from diracx.logic.jobs.sandboxes import SANDBOX_PFN_REGEX +from diracx.logic.jobs.sandboxes import ( + assign_sandbox_to_job as assign_sandbox_to_job_bl, +) +from diracx.logic.jobs.sandboxes import get_job_sandbox as get_job_sandbox_bl +from diracx.logic.jobs.sandboxes import get_job_sandboxes as get_job_sandboxes_bl +from diracx.logic.jobs.sandboxes import get_sandbox_file as get_sandbox_file_bl +from diracx.logic.jobs.sandboxes import ( + initiate_sandbox_upload as initiate_sandbox_upload_bl, +) +from diracx.logic.jobs.sandboxes import ( + unassign_jobs_sandboxes as unassign_jobs_sandboxes_bl, ) -from diracx.core.settings import ServiceSettingsBase +from ..dependencies import JobDB, SandboxMetadataDB, SandboxStoreSettings +from ..fastapi_classes import DiracxRouter from ..utils.users import AuthorizedUserInfo, verify_dirac_access_token from .access_policies import ( ActionType, @@ -32,63 +35,10 @@ CheckWMSPolicyCallable, ) -if TYPE_CHECKING: - from types_aiobotocore_s3.client import S3Client - -from ..dependencies import JobDB, SandboxMetadataDB, add_settings_annotation -from ..fastapi_classes import DiracxRouter - 
MAX_SANDBOX_SIZE_BYTES = 100 * 1024 * 1024 router = DiracxRouter() -@add_settings_annotation -class SandboxStoreSettings(ServiceSettingsBase): - """Settings for the sandbox store.""" - - model_config = SettingsConfigDict(env_prefix="DIRACX_SANDBOX_STORE_") - - bucket_name: str - s3_client_kwargs: dict[str, str] - auto_create_bucket: bool = False - url_validity_seconds: int = 5 * 60 - se_name: str = "SandboxSE" - _client: S3Client = PrivateAttr() - - @contextlib.asynccontextmanager - async def lifetime_function(self) -> AsyncIterator[None]: - async with get_session().create_client( - "s3", - **self.s3_client_kwargs, - config=Config(signature_version="v4"), - ) as self._client: # type: ignore - if not await s3_bucket_exists(self._client, self.bucket_name): - if not self.auto_create_bucket: - raise ValueError( - f"Bucket {self.bucket_name} does not exist and auto_create_bucket is disabled" - ) - try: - await self._client.create_bucket(Bucket=self.bucket_name) - except ClientError as e: - raise ValueError( - f"Failed to create bucket {self.bucket_name}" - ) from e - - yield - - @property - def s3_client(self) -> S3Client: - if self._client is None: - raise RuntimeError("S3 client accessed before lifetime function") - return self._client - - -class SandboxUploadResponse(BaseModel): - pfn: str - url: str | None = None - fields: dict[str, str] = {} - - @router.post("/sandbox") async def initiate_sandbox_upload( user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], @@ -105,78 +55,18 @@ async def initiate_sandbox_upload( If the sandbox does not exist in the database then the "url" and "fields" should be used to upload the sandbox to the storage backend. """ - pfn = sandbox_metadata_db.get_pfn(settings.bucket_name, user_info, sandbox_info) - full_pfn = f"SB:{settings.se_name}|{pfn}" - await check_permissions( - action=ActionType.CREATE, sandbox_metadata_db=sandbox_metadata_db, pfns=[pfn] - ) - - # TODO: THis test should come first, but if we do - # the access policy will crash for not having been called - # so we need to find a way to ackownledge that - - if sandbox_info.size > MAX_SANDBOX_SIZE_BYTES: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, - detail=f"Sandbox too large. Max size is {MAX_SANDBOX_SIZE_BYTES} bytes", - ) + await check_permissions(action=ActionType.CREATE) try: - exists_and_assigned = await sandbox_metadata_db.sandbox_is_assigned( - pfn, settings.se_name + sandbox_upload_response = await initiate_sandbox_upload_bl( + user_info, sandbox_info, sandbox_metadata_db, settings ) - except SandboxNotFoundError: - # The sandbox doesn't exist in the database - pass - else: - # As sandboxes are registered in the DB before uploading to the storage - # backend we can't rely on their existence in the database to determine if - # they have been uploaded. Instead we check if the sandbox has been - # assigned to a job. If it has then we know it has been uploaded and we - # can avoid communicating with the storage backend. 
- if exists_and_assigned or s3_object_exists( - settings.s3_client, settings.bucket_name, pfn_to_key(pfn) - ): - await sandbox_metadata_db.update_sandbox_last_access_time( - settings.se_name, pfn - ) - return SandboxUploadResponse(pfn=full_pfn) - - upload_info = await generate_presigned_upload( - settings.s3_client, - settings.bucket_name, - pfn_to_key(pfn), - sandbox_info.checksum_algorithm, - sandbox_info.checksum, - sandbox_info.size, - settings.url_validity_seconds, - ) - await sandbox_metadata_db.insert_sandbox( - settings.se_name, user_info, pfn, sandbox_info.size - ) - - return SandboxUploadResponse(**upload_info, pfn=full_pfn) - - -class SandboxDownloadResponse(BaseModel): - url: str - expires_in: int - - -def pfn_to_key(pfn: str) -> str: - """Convert a PFN to a key for S3. - - This removes the leading "/S3/" from the PFN. - """ - return "/".join(pfn.split("/")[3:]) - - -SANDBOX_PFN_REGEX = ( - # Starts with /S3/ or /SB:|/S3/ - r"^(:?SB:[A-Za-z]+\|)?/S3/[a-z0-9\.\-]{3,63}" - # Followed ////:. - r"(?:/[^/]+){3}/[a-z0-9]{3,10}:[0-9a-f]{64}\.[a-z0-9\.]+$" -) + except ValueError as e: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail=str(e), + ) from e + return sandbox_upload_response @router.get("/sandbox") @@ -184,7 +74,6 @@ async def get_sandbox_file( pfn: Annotated[str, Query(max_length=256, pattern=SANDBOX_PFN_REGEX)], settings: SandboxStoreSettings, user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], - sandbox_metadata_db: SandboxMetadataDB, check_permissions: CheckSandboxPolicyCallable, ) -> SandboxDownloadResponse: """Get a presigned URL to download a sandbox file. @@ -195,7 +84,7 @@ async def get_sandbox_file( most storage backends return an error when they receive an authorization header for a presigned URL. """ - pfn = pfn.split("|", 1)[-1] + short_pfn = pfn.split("|", 1)[-1] required_prefix = ( "/" + f"S3/{settings.bucket_name}/{user_info.vo}/{user_info.dirac_group}/{user_info.preferred_username}" @@ -203,20 +92,11 @@ async def get_sandbox_file( ) await check_permissions( action=ActionType.READ, - sandbox_metadata_db=sandbox_metadata_db, - pfns=[pfn], + pfns=[short_pfn], required_prefix=required_prefix, ) - # TODO: Support by name and by job id? 
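# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this patch): how a sandbox PFN is reduced
# to an S3 object key by the helper shown just above (now moved to
# diracx-logic). The PFN value below is made up and shortened for readability.
def _pfn_to_key_example(pfn: str) -> str:
    # Drop the leading "/S3/<bucket>/" portion of the PFN.
    return "/".join(pfn.split("/")[3:])


_full_pfn = "SB:SandboxSE|/S3/sandboxes/myvo/mygroup/someuser/sha256:0123abcd.tar.bz2"
# Strip the optional "SB:<SE name>|" prefix, as the router does before the
# access-policy check.
_short_pfn = _full_pfn.split("|", 1)[-1]
assert _short_pfn == "/S3/sandboxes/myvo/mygroup/someuser/sha256:0123abcd.tar.bz2"
# What is left after removing "/S3/<bucket>/" is the key used when generating
# the presigned download URL.
assert _pfn_to_key_example(_short_pfn) == "myvo/mygroup/someuser/sha256:0123abcd.tar.bz2"
# ---------------------------------------------------------------------------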
- presigned_url = await settings.s3_client.generate_presigned_url( - ClientMethod="get_object", - Params={"Bucket": settings.bucket_name, "Key": pfn_to_key(pfn)}, - ExpiresIn=settings.url_validity_seconds, - ) - return SandboxDownloadResponse( - url=presigned_url, expires_in=settings.url_validity_seconds - ) + return await get_sandbox_file_bl(pfn, settings) @router.get("/{job_id}/sandbox") @@ -228,14 +108,7 @@ async def get_job_sandboxes( ) -> dict[str, list[Any]]: """Get input and output sandboxes of given job.""" await check_permissions(action=ActionType.READ, job_db=job_db, job_ids=[job_id]) - - input_sb = await sandbox_metadata_db.get_sandbox_assigned_to_job( - job_id, SandboxType.Input - ) - output_sb = await sandbox_metadata_db.get_sandbox_assigned_to_job( - job_id, SandboxType.Output - ) - return {SandboxType.Input: input_sb, SandboxType.Output: output_sb} + return await get_job_sandboxes_bl(job_id, sandbox_metadata_db) @router.get("/{job_id}/sandbox/{sandbox_type}") @@ -248,11 +121,7 @@ async def get_job_sandbox( ) -> list[Any]: """Get input or output sandbox of given job.""" await check_permissions(action=ActionType.READ, job_db=job_db, job_ids=[job_id]) - job_sb_pfns = await sandbox_metadata_db.get_sandbox_assigned_to_job( - job_id, SandboxType(sandbox_type.capitalize()) - ) - - return job_sb_pfns + return await get_job_sandbox_bl(job_id, sandbox_metadata_db, sandbox_type) @router.patch("/{job_id}/sandbox/output") @@ -266,14 +135,8 @@ async def assign_sandbox_to_job( ): """Map the pfn as output sandbox to job.""" await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=[job_id]) - short_pfn = pfn.split("|", 1)[-1] try: - await sandbox_metadata_db.assign_sandbox_to_jobs( - jobs_ids=[job_id], - pfn=short_pfn, - sb_type=SandboxType.Output, - se_name=settings.se_name, - ) + await assign_sandbox_to_job_bl(job_id, pfn, sandbox_metadata_db, settings) except SandboxNotFoundError as e: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, detail="Sandbox not found" @@ -293,7 +156,7 @@ async def unassign_job_sandboxes( ): """Delete single job sandbox mapping.""" await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=[job_id]) - await sandbox_metadata_db.unassign_sandboxes_to_jobs([job_id]) + await unassign_jobs_sandboxes_bl([job_id], sandbox_metadata_db) @router.delete("/sandbox") @@ -305,4 +168,4 @@ async def unassign_bulk_jobs_sandboxes( ): """Delete bulk jobs sandbox mapping.""" await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=jobs_ids) - await sandbox_metadata_db.unassign_sandboxes_to_jobs(jobs_ids) + await unassign_jobs_sandboxes_bl(jobs_ids, sandbox_metadata_db) diff --git a/diracx-routers/src/diracx/routers/jobs/status.py b/diracx-routers/src/diracx/routers/jobs/status.py index ab9048ee..d7e3e34f 100644 --- a/diracx-routers/src/diracx/routers/jobs/status.py +++ b/diracx-routers/src/diracx/routers/jobs/status.py @@ -1,21 +1,18 @@ from __future__ import annotations -import logging from datetime import datetime from http import HTTPStatus -from typing import Annotated +from typing import Annotated, Any -from fastapi import BackgroundTasks, HTTPException, Query +from fastapi import HTTPException, Query from diracx.core.models import ( JobStatusUpdate, SetJobStatusReturn, ) -from diracx.db.sql.utils.job import ( - remove_jobs, - reschedule_jobs_bulk, - set_job_status_bulk, -) +from diracx.logic.jobs.status import remove_jobs as remove_jobs_bl +from diracx.logic.jobs.status import reschedule_jobs as reschedule_jobs_bl +from 
diracx.logic.jobs.status import set_job_statuses as set_job_statuses_bl from ..dependencies import ( Config, @@ -27,20 +24,17 @@ from ..fastapi_classes import DiracxRouter from .access_policies import ActionType, CheckWMSPolicyCallable -logger = logging.getLogger(__name__) - router = DiracxRouter() @router.delete("/") -async def remove_bulk_jobs( +async def remove_jobs( job_ids: Annotated[list[int], Query()], config: Config, job_db: JobDB, job_logging_db: JobLoggingDB, sandbox_metadata_db: SandboxMetadataDB, task_queue_db: TaskQueueDB, - background_task: BackgroundTasks, check_permissions: CheckWMSPolicyCallable, ): """Fully remove a list of jobs from the WMS databases. @@ -51,14 +45,13 @@ async def remove_bulk_jobs( """ await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=job_ids) - return await remove_jobs( + return await remove_jobs_bl( job_ids, config, job_db, job_logging_db, sandbox_metadata_db, task_queue_db, - background_task, ) @@ -69,30 +62,28 @@ async def set_job_statuses( job_db: JobDB, job_logging_db: JobLoggingDB, task_queue_db: TaskQueueDB, - background_task: BackgroundTasks, check_permissions: CheckWMSPolicyCallable, force: bool = False, ) -> SetJobStatusReturn: await check_permissions( action=ActionType.MANAGE, job_db=job_db, job_ids=list(job_update) ) - # check that the datetime contains timezone info - for job_id, status in job_update.items(): - for dt in status: - if dt.tzinfo is None: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, - detail=f"Timestamp {dt} is not timezone aware for job {job_id}", - ) - result = await set_job_status_bulk( - job_update, - config, - job_db, - job_logging_db, - task_queue_db, - background_task, - force=force, - ) + + try: + result = await set_job_statuses_bl( + status_changes=job_update, + config=config, + job_db=job_db, + job_logging_db=job_logging_db, + task_queue_db=task_queue_db, + force=force, + ) + except ValueError as e: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail=str(e), + ) from e + if not result.success: raise HTTPException( status_code=HTTPStatus.NOT_FOUND, @@ -103,26 +94,24 @@ async def set_job_statuses( @router.post("/reschedule") -async def reschedule_bulk_jobs( +async def reschedule_jobs( job_ids: Annotated[list[int], Query()], config: Config, job_db: JobDB, job_logging_db: JobLoggingDB, task_queue_db: TaskQueueDB, - background_task: BackgroundTasks, check_permissions: CheckWMSPolicyCallable, reset_jobs: Annotated[bool, Query()] = False, -): +) -> dict[str, Any]: await check_permissions(action=ActionType.MANAGE, job_db=job_db, job_ids=job_ids) - resched_jobs = await reschedule_jobs_bulk( + resched_jobs = await reschedule_jobs_bl( job_ids, config, job_db, job_logging_db, task_queue_db, - background_task, - reset_counter=reset_jobs, + reset_jobs=reset_jobs, ) if not resched_jobs.get("success", []): diff --git a/diracx-routers/src/diracx/routers/jobs/submission.py b/diracx-routers/src/diracx/routers/jobs/submission.py index cc1e0dea..54c9ff31 100644 --- a/diracx-routers/src/diracx/routers/jobs/submission.py +++ b/diracx-routers/src/diracx/routers/jobs/submission.py @@ -1,19 +1,13 @@ from __future__ import annotations -import logging -from datetime import datetime, timezone from http import HTTPStatus from typing import Annotated -from fastapi import Body, Depends, HTTPException, status +from fastapi import Body, Depends, HTTPException from pydantic import BaseModel -from typing_extensions import TypedDict -from diracx.core.models import ( - JobStatus, -) -from 
diracx.db.sql.job_logging.db import JobLoggingRecord -from diracx.db.sql.utils.job import JobSubmissionSpec, submit_jobs_jdl +from diracx.core.models import InsertedJob +from diracx.logic.jobs.submission import submit_jdl_jobs as submit_jdl_jobs_bl from ..dependencies import ( JobDB, @@ -23,18 +17,9 @@ from ..utils.users import AuthorizedUserInfo, verify_dirac_access_token from .access_policies import ActionType, CheckWMSPolicyCallable -logger = logging.getLogger(__name__) - router = DiracxRouter() -class InsertedJob(TypedDict): - JobID: int - Status: str - MinorStatus: str - TimeStamp: datetime - - class JobID(BaseModel): job_id: int @@ -69,134 +54,23 @@ class JobID(BaseModel): @router.post("/jdl") -async def submit_bulk_jdl_jobs( +async def submit_jdl_jobs( job_definitions: Annotated[list[str], Body(openapi_examples=EXAMPLE_JDLS)], job_db: JobDB, job_logging_db: JobLoggingDB, user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], check_permissions: CheckWMSPolicyCallable, ) -> list[InsertedJob]: - + """Submit a list of jobs in JDL format.""" await check_permissions(action=ActionType.CREATE, job_db=job_db) - from DIRAC.Core.Utilities.ClassAd.ClassAdLight import ClassAd - from DIRAC.WorkloadManagementSystem.Utilities.ParametricJob import ( - generateParametricJobs, - getParameterVectorLength, - ) - - # TODO: that needs to go in the legacy adapter (Does it ? Because bulk submission is not supported there) - for i in range(len(job_definitions)): - job_definition = job_definitions[i].strip() - if not (job_definition.startswith("[") and job_definition.endswith("]")): - job_definition = f"[{job_definition}]" - job_definitions[i] = job_definition - - if len(job_definitions) == 1: - # Check if the job is a parametric one - job_class_ad = ClassAd(job_definitions[0]) - result = getParameterVectorLength(job_class_ad) - if not result["OK"]: - # FIXME dont do this - print("Issue with getParameterVectorLength", result["Message"]) - return result - n_jobs = result["Value"] - parametric_job = False - if n_jobs is not None and n_jobs > 0: - # if we are here, then jobDesc was the description of a parametric job. So we start unpacking - parametric_job = True - result = generateParametricJobs(job_class_ad) - if not result["OK"]: - # FIXME why? - return result - job_desc_list = result["Value"] - else: - # if we are here, then jobDesc was the description of a single job. 
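# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this patch): the JDL normalisation that the
# removed inline code performed, and which submit_jdl_jobs_bl is expected to
# take over in diracx-logic. The JDL snippet is made up for illustration.
def _normalise_jdl_example(jdl: str) -> str:
    jdl = jdl.strip()
    if not (jdl.startswith("[") and jdl.endswith("]")):
        # Each JDL body is wrapped in brackets if the caller did not already
        # provide them.
        jdl = f"[{jdl}]"
    return jdl


assert _normalise_jdl_example('Executable = "echo";') == '[Executable = "echo";]'
assert _normalise_jdl_example('[Executable = "echo";]') == '[Executable = "echo";]'
# ---------------------------------------------------------------------------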
- job_desc_list = job_definitions - else: - # if we are here, then jobDesc is a list of JDLs - # we need to check that none of them is a parametric - for job_definition in job_definitions: - res = getParameterVectorLength(ClassAd(job_definition)) - if not res["OK"]: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, detail=res["Message"] - ) - if res["Value"]: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, - detail="You cannot submit parametric jobs in a bulk fashion", - ) - - job_desc_list = job_definitions - # parametric_job = True - parametric_job = False - - # TODO: make the max number of jobs configurable in the CS - if len(job_desc_list) > MAX_PARAMETRIC_JOBS: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, - detail=f"Normal user cannot submit more than {MAX_PARAMETRIC_JOBS} jobs at once", - ) - - result = [] - - if parametric_job: - initial_status = JobStatus.SUBMITTING - initial_minor_status = "Bulk transaction confirmation" - else: - initial_status = JobStatus.RECEIVED - initial_minor_status = "Job accepted" - try: - submitted_job_ids = await submit_jobs_jdl( - [ - JobSubmissionSpec( - jdl=jdl, - owner=user_info.preferred_username, - owner_group=user_info.dirac_group, - initial_status=initial_status, - initial_minor_status=initial_minor_status, - vo=user_info.vo, - ) - for jdl in job_desc_list - ], - job_db=job_db, + inserted_jobs = await submit_jdl_jobs_bl( + job_definitions, job_db, job_logging_db, user_info ) - except ExceptionGroup as e: + except ValueError as e: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="JDL syntax error", + status_code=HTTPStatus.BAD_REQUEST, + detail=str(e), ) from e - - logging.debug( - f'Jobs added to the JobDB", "{submitted_job_ids} for {user_info.preferred_username}/{user_info.dirac_group}' - ) - - job_created_time = datetime.now(timezone.utc) - await job_logging_db.bulk_insert_record( - [ - JobLoggingRecord( - job_id=int(job_id), - status=initial_status, - minor_status=initial_minor_status, - application_status="Unknown", - date=job_created_time, - source="JobManager", - ) - for job_id in submitted_job_ids - ] - ) - - # if not parametric_job: - # self.__sendJobsToOptimizationMind(submitted_job_ids) - - return [ - InsertedJob( - JobID=job_id, - Status=initial_status, - MinorStatus=initial_minor_status, - TimeStamp=job_created_time, - ) - for job_id in submitted_job_ids - ] + return inserted_jobs diff --git a/diracx-routers/src/diracx/routers/utils/users.py b/diracx-routers/src/diracx/routers/utils/users.py index 69c2a4ae..96856e24 100644 --- a/diracx-routers/src/diracx/routers/utils/users.py +++ b/diracx-routers/src/diracx/routers/utils/users.py @@ -7,13 +7,11 @@ from authlib.jose import JoseError, JsonWebToken from fastapi import Depends, HTTPException, status from fastapi.security import OpenIdConnect -from pydantic import BaseModel, Field -from pydantic_settings import SettingsConfigDict +from pydantic import BaseModel from diracx.core.models import UserInfo from diracx.core.properties import SecurityProperty -from diracx.core.settings import FernetKey, ServiceSettingsBase, TokenSigningKey -from diracx.routers.dependencies import Config, add_settings_annotation +from diracx.routers.dependencies import AuthSettings # auto_error=False is used to avoid raising the wrong exception when the token is missing # The error is handled in the verify_dirac_access_token function @@ -43,35 +41,6 @@ class AuthorizedUserInfo(AuthInfo, UserInfo): pass -@add_settings_annotation -class 
AuthSettings(ServiceSettingsBase): - """Settings for the authentication service.""" - - model_config = SettingsConfigDict(env_prefix="DIRACX_SERVICE_AUTH_") - - dirac_client_id: str = "myDIRACClientID" - # TODO: This should be taken dynamically - # ["http://pclhcb211:8000/docs/oauth2-redirect"] - allowed_redirects: list[str] = [] - device_flow_expiration_seconds: int = 600 - authorization_flow_expiration_seconds: int = 300 - - # State key is used to encrypt/decrypt the state dict passed to the IAM - state_key: FernetKey - - # TODO: this should probably be something mandatory - # to set by the user - token_issuer: str = "http://lhcbdirac.cern.ch/" # noqa: S105 - token_key: TokenSigningKey - token_algorithm: str = "RS256" # noqa: S105 - access_token_expire_minutes: int = 20 - refresh_token_expire_minutes: int = 60 - - available_properties: set[SecurityProperty] = Field( - default_factory=SecurityProperty.available_properties - ) - - async def verify_dirac_access_token( authorization: Annotated[str, Depends(oidc_scheme)], settings: AuthSettings, @@ -119,12 +88,3 @@ async def verify_dirac_access_token( vo=token["vo"], policies=token.get("dirac_policies", {}), ) - - -def get_allowed_user_properties(config: Config, sub, vo: str) -> set[SecurityProperty]: - """Retrieve all properties of groups a user is registered in.""" - allowed_user_properties = set() - for group in config.Registry[vo].Groups: - if sub in config.Registry[vo].Groups[group].Users: - allowed_user_properties.update(config.Registry[vo].Groups[group].Properties) - return allowed_user_properties diff --git a/diracx-routers/tests/auth/test_standard.py b/diracx-routers/tests/auth/test_standard.py index d883e7bf..1d7c5e6a 100644 --- a/diracx-routers/tests/auth/test_standard.py +++ b/diracx-routers/tests/auth/test_standard.py @@ -13,21 +13,21 @@ from cryptography.fernet import Fernet from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey -from fastapi import HTTPException from pytest_httpx import HTTPXMock from diracx.core.config import Config +from diracx.core.exceptions import AuthorizationError +from diracx.core.models import GrantType from diracx.core.properties import NORMAL_USER, PROXY_MANAGEMENT, SecurityProperty -from diracx.routers.auth.token import create_token -from diracx.routers.auth.utils import ( - GrantType, +from diracx.core.settings import AuthSettings +from diracx.logic.auth.token import create_token +from diracx.logic.auth.utils import ( _server_metadata_cache, decrypt_state, encrypt_state, get_server_metadata, parse_and_validate_scope, ) -from diracx.routers.utils.users import AuthSettings DIRAC_CLIENT_ID = "myDIRACClientID" pytestmark = pytest.mark.enabled_dependencies( @@ -70,7 +70,7 @@ def custom_response(request: httpx.Request): httpx_mock.add_callback(custom_response, url=server_metadata["token_endpoint"]) - monkeypatch.setattr("diracx.routers.auth.utils.parse_id_token", fake_parse_id_token) + monkeypatch.setattr("diracx.logic.auth.utils.parse_id_token", fake_parse_id_token) yield httpx_mock @@ -613,6 +613,7 @@ async def test_refresh_token_invalid(test_client, auth_httpx_mock: HTTPXMock): ).decode() new_auth_settings = AuthSettings( + token_issuer="https://iam-auth.web.cern.ch/", token_algorithm="EdDSA", token_key=pem, state_key=Fernet.generate_key(), @@ -1039,7 +1040,6 @@ def test_encrypt_decrypt_state_invalid_state(fernet_key): """Test that decrypt_state raises an error when the state is invalid.""" state = "invalid_state" # 
Invalid state string - with pytest.raises(HTTPException) as exc_info: + with pytest.raises(AuthorizationError) as exc_info: decrypt_state(state, fernet_key) - assert exc_info.value.status_code == 400 assert exc_info.value.detail == "Invalid state" diff --git a/diracx-routers/tests/jobs/test_sandboxes.py b/diracx-routers/tests/jobs/test_sandboxes.py index 36763ea1..7a30ba48 100644 --- a/diracx-routers/tests/jobs/test_sandboxes.py +++ b/diracx-routers/tests/jobs/test_sandboxes.py @@ -9,8 +9,8 @@ import pytest from fastapi.testclient import TestClient +from diracx.core.settings import AuthSettings from diracx.routers.auth.token import create_token -from diracx.routers.utils.users import AuthSettings pytestmark = pytest.mark.enabled_dependencies( [ diff --git a/diracx-routers/tests/jobs/test_wms_access_policy.py b/diracx-routers/tests/jobs/test_wms_access_policy.py index 6df6e675..9852d788 100644 --- a/diracx-routers/tests/jobs/test_wms_access_policy.py +++ b/diracx-routers/tests/jobs/test_wms_access_policy.py @@ -32,11 +32,6 @@ def job_db(): yield FakeDB() -@pytest.fixture -def sandbox_db(): - yield FakeDB() - - WMS_POLICY_NAME = "WMSAccessPolicy_AlthoughItDoesNotMatter" SANDBOX_POLICY_NAME = "SandboxAccessPolicy_AlthoughItDoesNotMatter" @@ -225,25 +220,16 @@ async def summary_other_vo(*args): ) -async def test_sandbox_access_policy_create(sandbox_db): +async def test_sandbox_access_policy_create(): admin_user = AuthorizedUserInfo(properties=[JOB_ADMINISTRATOR], **base_payload) normal_user = AuthorizedUserInfo(properties=[NORMAL_USER], **base_payload) - # sandbox_metadata_db and pfns are mandatory parameters - with pytest.raises(AssertionError): - await SandboxAccessPolicy.policy( - SANDBOX_POLICY_NAME, - normal_user, - action=ActionType.CREATE, - sandbox_metadata_db=sandbox_db, - ) + # action is a mandatory parameter with pytest.raises(AssertionError): await SandboxAccessPolicy.policy( SANDBOX_POLICY_NAME, normal_user, - action=ActionType.CREATE, - pfns=[USER_SANDBOX_PFN], ) # An admin cannot create any resource @@ -252,7 +238,6 @@ async def test_sandbox_access_policy_create(sandbox_db): SANDBOX_POLICY_NAME, admin_user, action=ActionType.CREATE, - sandbox_metadata_db=sandbox_db, pfns=[USER_SANDBOX_PFN], ) @@ -261,14 +246,13 @@ async def test_sandbox_access_policy_create(sandbox_db): SANDBOX_POLICY_NAME, normal_user, action=ActionType.CREATE, - sandbox_metadata_db=sandbox_db, pfns=[USER_SANDBOX_PFN], ) ############## -async def test_sandbox_access_policy_read(sandbox_db): +async def test_sandbox_access_policy_read(): admin_user = AuthorizedUserInfo(properties=[JOB_ADMINISTRATOR], **base_payload) normal_user = AuthorizedUserInfo(properties=[NORMAL_USER], **base_payload) @@ -277,7 +261,6 @@ async def test_sandbox_access_policy_read(sandbox_db): SANDBOX_POLICY_NAME, admin_user, action=ActionType.READ, - sandbox_metadata_db=sandbox_db, pfns=[USER_SANDBOX_PFN], required_prefix=SANDBOX_PREFIX, ) @@ -286,7 +269,6 @@ async def test_sandbox_access_policy_read(sandbox_db): SANDBOX_POLICY_NAME, admin_user, action=ActionType.READ, - sandbox_metadata_db=sandbox_db, pfns=[OTHER_USER_SANDBOX_PFN], required_prefix=SANDBOX_PREFIX, ) @@ -297,7 +279,6 @@ async def test_sandbox_access_policy_read(sandbox_db): SANDBOX_POLICY_NAME, normal_user, action=ActionType.READ, - sandbox_metadata_db=sandbox_db, pfns=[USER_SANDBOX_PFN], ) @@ -306,7 +287,6 @@ async def test_sandbox_access_policy_read(sandbox_db): SANDBOX_POLICY_NAME, normal_user, action=ActionType.READ, - sandbox_metadata_db=sandbox_db, 
pfns=[USER_SANDBOX_PFN], required_prefix=SANDBOX_PREFIX, ) @@ -317,7 +297,6 @@ async def test_sandbox_access_policy_read(sandbox_db): SANDBOX_POLICY_NAME, normal_user, action=ActionType.READ, - sandbox_metadata_db=sandbox_db, pfns=[OTHER_USER_SANDBOX_PFN], required_prefix=SANDBOX_PREFIX, ) diff --git a/diracx-routers/tests/test_generic.py b/diracx-routers/tests/test_generic.py index 659f0b5d..c4edd5c2 100644 --- a/diracx-routers/tests/test_generic.py +++ b/diracx-routers/tests/test_generic.py @@ -31,7 +31,7 @@ def test_openapi(test_client): def test_oidc_configuration(test_client): r = test_client.get("/.well-known/openid-configuration") - assert r.status_code == 200 + assert r.status_code == 200, r.json() assert r.json() diff --git a/diracx-routers/tests/test_job_manager.py b/diracx-routers/tests/test_job_manager.py index 0ae0d0b1..16aadfb4 100644 --- a/diracx-routers/tests/test_job_manager.py +++ b/diracx-routers/tests/test_job_manager.py @@ -1679,7 +1679,7 @@ def test_remove_job_invalid_job_id(normal_user_client: TestClient, invalid_job_i assert r.status_code == 200, r.json() -def test_remove_bulk_jobs_valid_job_ids( +def test_remove_jobs_valid_job_ids( normal_user_client: TestClient, valid_job_ids: list[int] ): # Act @@ -1786,7 +1786,7 @@ def test_set_single_job_properties_non_existing_job( assert res.status_code == HTTPStatus.NOT_FOUND, res.json() -# def test_remove_bulk_jobs_invalid_job_ids( +# def test_remove_jobs_invalid_job_ids( # normal_user_client: TestClient, invalid_job_ids: list[int] # ): # # Act @@ -1805,7 +1805,7 @@ def test_set_single_job_properties_non_existing_job( # } -# def test_remove_bulk_jobs_mix_of_valid_and_invalid_job_ids( +# def test_remove_jobs_mix_of_valid_and_invalid_job_ids( # normal_user_client: TestClient, valid_job_ids: list[int], invalid_job_ids: list[int] # ): # # Arrange diff --git a/diracx-testing/src/diracx/testing/utils.py b/diracx-testing/src/diracx/testing/utils.py index 45f8ae6b..25976c75 100644 --- a/diracx-testing/src/diracx/testing/utils.py +++ b/diracx-testing/src/diracx/testing/utils.py @@ -21,10 +21,15 @@ import httpx import pytest +from diracx.core.models import AccessTokenPayload, RefreshTokenPayload + if TYPE_CHECKING: - from diracx.core.settings import DevelopmentSettings - from diracx.routers.jobs.sandboxes import SandboxStoreSettings - from diracx.routers.utils.users import AuthorizedUserInfo, AuthSettings + from diracx.core.settings import ( + AuthSettings, + DevelopmentSettings, + SandboxStoreSettings, + ) + from diracx.routers.utils.users import AuthorizedUserInfo # to get a string like this run: @@ -92,9 +97,10 @@ def test_dev_settings() -> Generator[DevelopmentSettings, None, None]: def test_auth_settings( private_key_pem, fernet_key ) -> Generator[AuthSettings, None, None]: - from diracx.routers.utils.users import AuthSettings + from diracx.core.settings import AuthSettings yield AuthSettings( + token_issuer=ISSUER, token_algorithm="EdDSA", token_key=private_key_pem, state_key=fernet_key, @@ -128,7 +134,7 @@ def aio_moto(worker_id): @pytest.fixture(scope="session") def test_sandbox_settings(aio_moto) -> SandboxStoreSettings: - from diracx.routers.jobs.sandboxes import SandboxStoreSettings + from diracx.core.settings import SandboxStoreSettings yield SandboxStoreSettings( bucket_name="sandboxes", @@ -177,7 +183,9 @@ async def policy( pass @staticmethod - def enrich_tokens(access_payload: dict, refresh_payload: dict): + def enrich_tokens( + access_payload: AccessTokenPayload, refresh_payload: RefreshTokenPayload + ): return 
{"PolicySpecific": "OpenAccessForTest"}, {} @@ -186,13 +194,13 @@ def enrich_tokens(access_payload: dict, refresh_payload: dict): } database_urls = { e.name: "sqlite+aiosqlite:///:memory:" - for e in select_from_extension(group="diracx.db.sql") + for e in select_from_extension(group="diracx.dbs.sql") } # TODO: Monkeypatch this in a less stupid way # TODO: Only use this if opensearch isn't available os_database_conn_kwargs = { e.name: {"sqlalchemy_dsn": "sqlite+aiosqlite:///:memory:"} - for e in select_from_extension(group="diracx.db.os") + for e in select_from_extension(group="diracx.dbs.os") } BaseOSDB.available_implementations = partial( fake_available_osdb_implementations, diff --git a/docs/CODING_CONVENTION.md b/docs/CODING_CONVENTION.md index 50363485..4bb3536a 100644 --- a/docs/CODING_CONVENTION.md +++ b/docs/CODING_CONVENTION.md @@ -77,3 +77,10 @@ class Owners(Base): * `__init__.py` should not contain code, but `__all__` * at a package level (router for example) we have one file per system (configuration.py for example) * If we need more files (think of jobs, which have the sandbox, the joblogging, etc), we put them in a sub module (e.g routers.job). The code goes in a specific file (job.py, joblogging.py) but we use the the __init__.py to expose the specific file + + +# Architecture + +* `diracx-routers` should deal with user interactions through HTTPs. It is expected to deal with permissions and should call `diracx-logic`. Results returned should be translated into HTTP responses. +* `diracx-logic` should embed Dirac specificities. It should encapsulate the logic of the services and should call `diracx-db` to interact with databases. +* `diracx-db` should contain atomic methods (complex logic is expected to be located in `diracx-db`). diff --git a/docs/DATABASES.md b/docs/DATABASES.md index fbcfca03..435f15f9 100644 --- a/docs/DATABASES.md +++ b/docs/DATABASES.md @@ -2,7 +2,7 @@ The primary store of operational data in DiracX is in SQL databases managed through SQLAlchemy. In addition, DiracX utilizes OpenSearch (or Elasticsearch) for storing pilot logs, medium-term metadata about jobs and pilots ("Job Parameters" and "Pilot Parameters"), and optionally, for OpenTelemetry data. -Access to databases is managed by the `diracx.dbs` package. +Access to databases is managed by the `diracx-db` package. ## SQL Databases @@ -15,7 +15,7 @@ For convenience SQLite is used for testing and development however this should n Connections to DiracX DBs are configured using the [SQLAlchemy connection URL format](https://docs.sqlalchemy.org/en/20/core/engines.html#database-urls). DiracX requires that the `driver` part of the URL is always specified and it must refer to an async-compatible backend. -The value of this URL is taken from the environment vairable of the form `DIRACX_DB_URL_`, where `` is defined by the entry in the `diracx.db.sql` entrypoint in the `pyproject.toml`. +The value of this URL is taken from the environment vairable of the form `DIRACX_DB_URL_`, where `` is defined by the entry in the `diracx.dbs.sql` entrypoint in the `pyproject.toml`. ```bash export DIRACX_DB_URL_MYDB="mysql+aiomysql://user:pass@hostname:3306/MyDB" diff --git a/docs/SERVICES.md b/docs/SERVICES.md index c436f694..5e529ccb 100644 --- a/docs/SERVICES.md +++ b/docs/SERVICES.md @@ -61,7 +61,7 @@ Usage example: ```python @router.get("/openid-configuration") -async def openid_configuration(settings: AuthSettings): +async def get_openid_configuration(settings: AuthSettings): ... 
``` diff --git a/docs/TESTING.md b/docs/TESTING.md index e06c82a7..3638456c 100644 --- a/docs/TESTING.md +++ b/docs/TESTING.md @@ -4,6 +4,7 @@ What we do * run the integration tests (against the demo) in a single job For the unit test, we start with a crude conda environment, and do pip install of the package. +Note: `diracx-logic` does not contain any unit tests, developers are expected to run tests from `diracx-routers`. For the integration tests, we always use the [services|tasks|client] dev image and do a pip install directly with ``--no-deps``. diff --git a/docs/VERSIONING.md b/docs/VERSIONING.md index 3dd0a262..f3b870de 100644 --- a/docs/VERSIONING.md +++ b/docs/VERSIONING.md @@ -15,8 +15,9 @@ DiracX is a comprehensive Python package, composed of several interconnected sub DiracX is structured into various modules, each serving a distinct purpose: - **`diracx-core`**: The foundational code base, utilized by all other DiracX modules. -- **`diracx-db`**: Focuses on database functionalities. -- **`diracx-routers`**: Implements a FastAPI application. +- **`diracx-db`**: Data Access Layer, focuses on database functionalities. +- **`diracx-logic`**: Business Logic Layer, comprises Dirac logic. +- **`diracx-routers`**: Presentation Layer, handles user interactions through HTTP using a FastAPI application. - **`diracx-client`**: A client auto-generated from the OpenAPI specification in `diracx-routers`. - **`diracx-api`**: Provides higher-level operations building on `diracx-client`. - **`diracx-cli`**: The command line interface (`dirac`). @@ -26,19 +27,47 @@ These modules are each implemented as a [native Python namespace package](https: The direct dependencies between the submodules are as follows: -``` - ┌──────┐ - ┌──────┤ core ├─────────┐ - │ └──────┘ │ - ┌──▼─┐ ┌────▼───┐ - │ db ├─────┐ │ client │ - └─┬──┘ │ └────┬───┘ -┌────▼────┐ │ ┌──▼──┐ -│ routers │ │ ┌────────┤ api │ -└─────────┘ │ │ └──┬──┘ - ┌─▼───▼─┐ ┌──▼──┐ - │ tasks │ │ cli │ - └───────┘ └─────┘ +```mermaid +--- +config: + layout: elk +--- +flowchart BT + subgraph frontend["Frontend"] + client["diracx-client (autorest)"] + api["diracx-api"] + cli["diracx-cli (typer)"] + end + subgraph backend["Backend"] + dbs["diracx-db (sqlalchemy/os)"] + logic["diracx-logic (Dirac)"] + routers["diracx-routers (FastAPI)"] + end + dbs -. uses .-> core["diracx-core (domain)"] + logic -. uses .-> core + routers -. uses .-> core + tasks["diracx-tasks (celery?)"] -. uses .-> core + client -. uses .-> core + api -. uses .-> core + cli -. 
uses .-> core + logic -- calls --> dbs + routers -- calls --> logic + tasks -- calls --> logic & api + client -- calls through OpenAPI --> routers + api -- calls --> client + cli -- calls --> api & client + client:::Sky + api:::Sky + cli:::Sky + dbs:::Pine + logic:::Pine + routers:::Pine + tasks:::Aqua + classDef Rose stroke-width:1px, stroke-dasharray:none, stroke:#FF5978, fill:#FFDFE5, color:#8E2236 + classDef Sky stroke-width:1px, stroke-dasharray:none, stroke:#374D7C, fill:#E2EBFF, color:#374D7C + classDef Pine stroke-width:1px, stroke-dasharray:none, stroke:#254336, fill:#27654A, color:#FFFFFF + classDef Aqua stroke-width:1px, stroke-dasharray:none, stroke:#46EDC8, fill:#DEFFF8, color:#378E7A + ``` diff --git a/extensions/gubbins/gubbins-core/src/gubbins/core/__init__.py b/extensions/gubbins/gubbins-core/src/gubbins/core/__init__.py index b05c9dee..28d47ffb 100644 --- a/extensions/gubbins/gubbins-core/src/gubbins/core/__init__.py +++ b/extensions/gubbins/gubbins-core/src/gubbins/core/__init__.py @@ -1 +1 @@ -__all__ = ("config", "properties") +__all__ = ("config", "properties", "models") diff --git a/extensions/gubbins/gubbins-core/src/gubbins/core/models.py b/extensions/gubbins/gubbins-core/src/gubbins/core/models.py new file mode 100644 index 00000000..1a0e4c73 --- /dev/null +++ b/extensions/gubbins/gubbins-core/src/gubbins/core/models.py @@ -0,0 +1,6 @@ +from diracx.core.models import Metadata + + +class ExtendedMetadata(Metadata): + gubbins_secrets: str + gubbins_user_info: dict[str, list[str | None]] diff --git a/extensions/gubbins/gubbins-db/pyproject.toml b/extensions/gubbins/gubbins-db/pyproject.toml index 127196da..3e11fda5 100644 --- a/extensions/gubbins/gubbins-db/pyproject.toml +++ b/extensions/gubbins/gubbins-db/pyproject.toml @@ -22,7 +22,7 @@ dynamic = ["version"] [project.optional-dependencies] testing = ["gubbins-testing", "diracx-testing"] -[project.entry-points."diracx.db.sql"] +[project.entry-points."diracx.dbs.sql"] LollygagDB = "gubbins.db.sql:LollygagDB" JobDB = "gubbins.db.sql:GubbinsJobDB" diff --git a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/db.py b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/db.py index 414bc23d..cfdd6d9b 100644 --- a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/db.py +++ b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/db.py @@ -47,7 +47,7 @@ async def get_job_jdls( # type: ignore[override] result[job_id] = {"JDL": jdl_details, "Info": info.get(job_id, "")} return result - async def set_job_attributes_bulk(self, job_data): + async def set_job_attributes(self, job_data): """ This method modified the one in the parent class, without changing the argument nor the return type diff --git a/extensions/gubbins/gubbins-db/tests/test_gubbins_job_db.py b/extensions/gubbins/gubbins-db/tests/test_gubbins_job_db.py index a9d21362..37a3e793 100644 --- a/extensions/gubbins/gubbins-db/tests/test_gubbins_job_db.py +++ b/extensions/gubbins/gubbins-db/tests/test_gubbins_job_db.py @@ -3,7 +3,6 @@ from typing import AsyncGenerator import pytest -from diracx.db.sql.utils.job import JobSubmissionSpec, submit_jobs_jdl from gubbins.db.sql import GubbinsJobDB @@ -20,6 +19,13 @@ async def gubbins_db() -> AsyncGenerator[GubbinsJobDB, None]: yield gubbins_db +@pytest.fixture +async def populated_job_db(job_db): + """Populate the in-memory JobDB with 100 jobs using DAL calls.""" + + yield job_db + + async def test_gubbins_info(gubbins_db): """ This test makes sure that we can: @@ -28,25 +34,24 @@ async def 
test_gubbins_info(gubbins_db): * use a method modified in the child db (getJobJDL) """ async with gubbins_db as gubbins_db: - job_ids = await submit_jobs_jdl( - [ - JobSubmissionSpec( - jdl="JDL", - owner="owner_toto", - owner_group="owner_group1", - initial_status="New", - initial_minor_status="dfdfds", - vo="lhcb", - ) - ], - gubbins_db, - ) - await gubbins_db.insert_gubbins_info(job_ids[0], "info") - - result = await gubbins_db.get_job_jdls(job_ids, original=True) - assert result == {1: "[JDL]"} - - result = await gubbins_db.get_job_jdls(job_ids, with_info=True) + compressed_jdl = "CompressedJDL" + job_id = await gubbins_db.create_job(compressed_jdl) + job_attr = { + "JobID": job_id, + "Status": "New", + "MinorStatus": "dfdfds", + "Owner": "owner_toto", + "OwnerGroup": "owner_group1", + "VO": "lhcb", + } + await gubbins_db.insert_job_attributes({job_id: job_attr}) + + await gubbins_db.insert_gubbins_info(job_id, "info") + + result = await gubbins_db.get_job_jdls([job_id], original=True) + assert result == {1: "CompressedJDL"} + + result = await gubbins_db.get_job_jdls([job_id], original=True, with_info=True) assert len(result) == 1 assert result[1].get("JDL") assert result[1].get("Info") == "info" diff --git a/extensions/gubbins/gubbins-logic/pyproject.toml b/extensions/gubbins/gubbins-logic/pyproject.toml new file mode 100644 index 00000000..c19040c0 --- /dev/null +++ b/extensions/gubbins/gubbins-logic/pyproject.toml @@ -0,0 +1,45 @@ +[project] +name = "gubbins-logic" +description = "TODO" +readme = "README.md" +requires-python = ">=3.11" +keywords = [] +license = { text = "GPL-3.0-only" } +classifiers = [ + "Intended Audience :: Science/Research", + "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", + "Programming Language :: Python :: 3", + "Topic :: Scientific/Engineering", + "Topic :: System :: Distributed Computing", +] + +dependencies = [ + # This is obvious + "diracx-logic", + # We should add something else +] + +dynamic = ["version"] + +[project.optional-dependencies] +types = [ + "boto3-stubs", + "types-aiobotocore[essential]", + "types-aiobotocore-s3", + "types-cachetools", + "types-python-dateutil", + "types-PyYAML", +] + +[tool.setuptools.packages.find] +where = ["src"] + +[build-system] +requires = ["setuptools>=61", "wheel", "setuptools_scm>=8"] +build-backend = "setuptools.build_meta" + +[tool.setuptools_scm] +root = "../../.." 
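
The rewritten `test_gubbins_info` above reflects that the old `submit_jobs_jdl` helper (previously imported from `diracx.db.sql.utils.job`) no longer belongs at the DB layer: tests now drive the primitive `create_job` / `insert_job_attributes` calls directly, and job-submission composition is expected to live above `gubbins-db`. The sketch below is illustrative only, not part of the patch: the method names `create_job`, `insert_job_attributes` and `insert_gubbins_info` are taken from the test above, while the helper name, its signature and the attribute values are hypothetical.

```python
# Illustrative only: a hypothetical gubbins-logic helper showing how the
# business-logic layer could compose the DAL primitives that the rewritten
# test calls directly. Only create_job, insert_job_attributes and
# insert_gubbins_info come from the diff above; everything else is a sketch.
from gubbins.db.sql import GubbinsJobDB


async def submit_gubbins_job(
    job_db: GubbinsJobDB,
    compressed_jdl: str,
    owner: str,
    owner_group: str,
    vo: str,
    gubbins_info: str,
) -> int:
    """Create a job, set its initial attributes and attach gubbins info."""
    async with job_db as db:
        # Reserve a job ID and store the (compressed) JDL.
        job_id = await db.create_job(compressed_jdl)
        # Set the initial attributes, keyed by job ID, mirroring the test above.
        await db.insert_job_attributes(
            {
                job_id: {
                    "JobID": job_id,
                    "Status": "New",
                    "MinorStatus": "Job accepted",
                    "Owner": owner,
                    "OwnerGroup": owner_group,
                    "VO": vo,
                }
            }
        )
        # Gubbins-specific extra information.
        await db.insert_gubbins_info(job_id, gubbins_info)
    return job_id
```

Whether such a helper actually ends up in `gubbins-logic` is a design choice; the point is that composing DAL primitives now happens above `gubbins-db`, which is why the DB test no longer imports submission logic.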
+ +[tool.pytest.ini_options] +asyncio_mode = "auto" diff --git a/extensions/gubbins/gubbins-logic/src/gubbins/logic/__init__.py b/extensions/gubbins/gubbins-logic/src/gubbins/logic/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/extensions/gubbins/gubbins-logic/src/gubbins/logic/auth/__init__.py b/extensions/gubbins/gubbins-logic/src/gubbins/logic/auth/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/extensions/gubbins/gubbins-logic/src/gubbins/logic/auth/well_known.py b/extensions/gubbins/gubbins-logic/src/gubbins/logic/auth/well_known.py new file mode 100644 index 00000000..65c7c3f8 --- /dev/null +++ b/extensions/gubbins/gubbins-logic/src/gubbins/logic/auth/well_known.py @@ -0,0 +1,28 @@ +from diracx.logic.auth.well_known import ( + get_installation_metadata as get_general_installation_metadata, +) + +from gubbins.core.config.schema import Config +from gubbins.core.models import ExtendedMetadata + + +async def get_installation_metadata( + config: Config, +) -> ExtendedMetadata: + """Get metadata about the dirac installation.""" + original_metadata = await get_general_installation_metadata(config) + + gubbins_user_info: dict[str, list[str | None]] = {} + for vo in config.Registry: + vo_gubbins = [ + user.GubbinsSpecificInfo for user in config.Registry[vo].Users.values() + ] + gubbins_user_info[vo] = vo_gubbins + + gubbins_metadata = ExtendedMetadata( + gubbins_secrets="hush!", + virtual_organizations=original_metadata["virtual_organizations"], + gubbins_user_info=gubbins_user_info, + ) + + return gubbins_metadata diff --git a/extensions/gubbins/gubbins-logic/src/gubbins/logic/lollygag/__init__.py b/extensions/gubbins/gubbins-logic/src/gubbins/logic/lollygag/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/extensions/gubbins/gubbins-logic/src/gubbins/logic/lollygag/lollygag.py b/extensions/gubbins/gubbins-logic/src/gubbins/logic/lollygag/lollygag.py new file mode 100644 index 00000000..0fdbad8f --- /dev/null +++ b/extensions/gubbins/gubbins-logic/src/gubbins/logic/lollygag/lollygag.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from gubbins.db.sql import LollygagDB + + +async def insert_owner_object( + lollygag_db: LollygagDB, + owner_name: str, +): + return await lollygag_db.insert_owner(owner_name) + + +async def get_owner_object( + lollygag_db: LollygagDB, +): + return await lollygag_db.get_owner() + + +async def get_gubbins_secrets( + lollygag_db: LollygagDB, +): + """Does nothing but expects a GUBBINS_SENSEI permission""" + return await lollygag_db.get_owner() diff --git a/extensions/gubbins/gubbins-logic/src/gubbins/logic/py.typed b/extensions/gubbins/gubbins-logic/src/gubbins/logic/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/extensions/gubbins/gubbins-routers/src/gubbins/routers/lollygag/lollygag.py b/extensions/gubbins/gubbins-routers/src/gubbins/routers/lollygag/lollygag.py index 1e69c06c..91b5014f 100644 --- a/extensions/gubbins/gubbins-routers/src/gubbins/routers/lollygag/lollygag.py +++ b/extensions/gubbins/gubbins-routers/src/gubbins/routers/lollygag/lollygag.py @@ -11,6 +11,13 @@ from fastapi import Depends from gubbins.db.sql import LollygagDB as _LollygagDB +from gubbins.logic.lollygag.lollygag import ( + get_gubbins_secrets as get_gubbins_secrets_bl, +) +from gubbins.logic.lollygag.lollygag import get_owner_object as get_owner_object_bl +from gubbins.logic.lollygag.lollygag import ( + insert_owner_object as insert_owner_object_bl, +) from .access_policy import ActionType, 
CheckLollygagPolicyCallable @@ -28,7 +35,7 @@ async def insert_owner_object( check_permission: CheckLollygagPolicyCallable, ): await check_permission(action=ActionType.CREATE) - return await lollygag_db.insert_owner(owner_name) + return await insert_owner_object_bl(lollygag_db, owner_name) @router.get("/get_owners") @@ -37,7 +44,7 @@ async def get_owner_object( check_permission: CheckLollygagPolicyCallable, ): await check_permission(action=ActionType.READ) - return await lollygag_db.get_owner() + return await get_owner_object_bl(lollygag_db) @router.get("/gubbins_sensei") @@ -47,4 +54,4 @@ async def get_gubbins_secrets( ): """Does nothing but expects a GUBBINS_SENSEI permission""" await check_permission(action=ActionType.MANAGE) - return await lollygag_db.get_owner() + return await get_gubbins_secrets_bl(lollygag_db) diff --git a/extensions/gubbins/gubbins-routers/src/gubbins/routers/well_known.py b/extensions/gubbins/gubbins-routers/src/gubbins/routers/well_known.py index c7035362..a443c2a7 100644 --- a/extensions/gubbins/gubbins-routers/src/gubbins/routers/well_known.py +++ b/extensions/gubbins/gubbins-routers/src/gubbins/routers/well_known.py @@ -6,47 +6,23 @@ * uses the Gubbins dependencies """ -from diracx.routers.auth.well_known import Metadata -from diracx.routers.auth.well_known import ( - installation_metadata as _installation_metadata, -) from diracx.routers.auth.well_known import router as diracx_wellknown_router -from diracx.routers.dependencies import DevelopmentSettings from diracx.routers.fastapi_classes import DiracxRouter +from gubbins.core.models import ExtendedMetadata +from gubbins.logic.auth.well_known import ( + get_installation_metadata as get_installation_metadata_bl, +) from gubbins.routers.dependencies import Config router = DiracxRouter(require_auth=False, path_root="") router.include_router(diracx_wellknown_router) -# Change slightly the return type -class ExtendedMetadata(Metadata): - gubbins_secrets: str - gubbins_user_info: dict[str, list[str | None]] - - # Overwrite the dirac-metadata endpoint and add an extra metadata # This also makes sure that we can get Config as a GubbinsConfig @router.get("/dirac-metadata") -async def installation_metadata( +async def get_installation_metadata( config: Config, - dev_settings: DevelopmentSettings, ) -> ExtendedMetadata: - original_metadata = await _installation_metadata(config, dev_settings) - - gubbins_user_info: dict[str, list[str | None]] = {} - for vo in config.Registry: - vo_gubbins = [ - user.GubbinsSpecificInfo for user in config.Registry[vo].Users.values() - ] - gubbins_user_info[vo] = vo_gubbins - - gubbins_metadata = ExtendedMetadata( - gubbins_secrets="hush!", - virtual_organizations=original_metadata["virtual_organizations"], - development_settings=original_metadata["development_settings"], - gubbins_user_info=gubbins_user_info, - ) - - return gubbins_metadata + return await get_installation_metadata_bl(config) diff --git a/extensions/gubbins/pyproject.toml b/extensions/gubbins/pyproject.toml index c61127cb..a57b5b6a 100644 --- a/extensions/gubbins/pyproject.toml +++ b/extensions/gubbins/pyproject.toml @@ -82,6 +82,7 @@ files = [ "gubbins-client/src/gubbins/client/patches/**/*.py", # "gubbins-core/src/**/*.py", "gubbins-db/src/**/*.py", + "gubbins-logic/src/**/*.py", "gubbins-routers/src/**/*.py", ] mypy_path = [ @@ -90,6 +91,7 @@ mypy_path = [ "$MYPY_CONFIG_FILE_DIR/gubbins-client/src", # "$MYPY_CONFIG_FILE_DIR/gubbins-core/src", "$MYPY_CONFIG_FILE_DIR/gubbins-db/src", + 
"$MYPY_CONFIG_FILE_DIR/gubbins-logic/src", "$MYPY_CONFIG_FILE_DIR/gubbins-routers/src", ] plugins = ["sqlalchemy.ext.mypy.plugin", "pydantic.mypy"] diff --git a/pyproject.toml b/pyproject.toml index 2429d06d..3980ee87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,12 @@ classifiers = [ "Topic :: Scientific/Engineering", "Topic :: System :: Distributed Computing", ] -dependencies = ["diracx-api", "diracx-cli", "diracx-client", "diracx-core"] +dependencies = [ + "diracx-api", + "diracx-cli", + "diracx-client", + "diracx-core", +] dynamic = ["version"] [project.optional-dependencies] @@ -102,6 +107,7 @@ files = [ "diracx-client/src/**/_patch.py", "diracx-core/src/**/*.py", "diracx-db/src/**/*.py", + "diracx-logic/src/**/*.py", "diracx-routers/src/**/*.py", ] mypy_path = [ @@ -110,6 +116,7 @@ mypy_path = [ "$MYPY_CONFIG_FILE_DIR/diracx-client/src", "$MYPY_CONFIG_FILE_DIR/diracx-core/src", "$MYPY_CONFIG_FILE_DIR/diracx-db/src", + "$MYPY_CONFIG_FILE_DIR/diracx-logic/src", "$MYPY_CONFIG_FILE_DIR/diracx-routers/src", ] plugins = ["sqlalchemy.ext.mypy.plugin", "pydantic.mypy"] diff --git a/run_local.sh b/run_local.sh index 83bfbcc4..b922b5f1 100755 --- a/run_local.sh +++ b/run_local.sh @@ -38,6 +38,7 @@ export DIRACX_OS_DB_JOBPARAMETERSDB='{"sqlalchemy_dsn": "sqlite+aiosqlite:///'${ export DIRACX_SERVICE_AUTH_TOKEN_KEY="file://${signing_key}" export DIRACX_SERVICE_AUTH_STATE_KEY="${state_key}" hostname_lower=$(hostname | tr -s '[:upper:]' '[:lower:]') +export DIRACX_SERVICE_AUTH_TOKEN_ISSUER="http://${hostname_lower}:8000" export DIRACX_SERVICE_AUTH_ALLOWED_REDIRECTS='["http://'"$hostname_lower"':8000/docs/oauth2-redirect"]' export DIRACX_SANDBOX_STORE_BUCKET_NAME=sandboxes export DIRACX_SANDBOX_STORE_AUTO_CREATE_BUCKET=true diff --git a/tests/make_token_local.py b/tests/make_token_local.py index 3ddcd204..b02a6789 100755 --- a/tests/make_token_local.py +++ b/tests/make_token_local.py @@ -8,9 +8,9 @@ from diracx.core.models import TokenResponse from diracx.core.properties import NORMAL_USER +from diracx.core.settings import AuthSettings from diracx.core.utils import write_credentials -from diracx.routers.auth.token import create_token -from diracx.routers.utils.users import AuthSettings +from diracx.logic.auth.token import create_token def parse_args():