From 6aae6638bf13b7bc92711b1b7512a866d692013e Mon Sep 17 00:00:00 2001 From: Russ Allbery Date: Thu, 5 Dec 2024 17:04:11 -0800 Subject: [PATCH] Convert the UWS library to use the Wobbly backend Rather than storing UWS jobs directly in a database, which requires every UWS-based application to manage its own separate database, use the Wobbly service to manage all job storage. This service uses a delegated token to determine the user and service, so considerably less tracking of the user is required. UWS applications now store the serialized parameter model in the database rather than a list of key/value pairs, and rely on methods on the parameters model to convert to the XML format for the current IVOA UWS protocol. Add a mock for Wobbly that can be used to test UWS applications without having the Wobbly API available. Drop the `ErrorCode` enum, since its values were specific to SODA, and instead take the error code as a string. Drop some related exceptions that are not used directly in Safir and are specific to SODA. 
--- changelog.d/20241209_145305_rra_DM_47986.md | 7 + docs/_rst_epilog.rst | 2 + docs/conf.py | 4 + docs/user-guide/database/schema.rst | 2 - docs/user-guide/uws/create-a-service.rst | 104 +-- docs/user-guide/uws/define-inputs.rst | 69 +- docs/user-guide/uws/define-models.rst | 43 +- docs/user-guide/uws/index.rst | 4 +- docs/user-guide/uws/testing.rst | 155 ++-- docs/user-guide/uws/write-backend.rst | 2 +- noxfile.py | 7 +- safir-arq/src/safir/arq/uws.py | 4 +- safir/pyproject.toml | 3 - safir/src/safir/testing/uws.py | 306 ++++++-- safir/src/safir/uws/__init__.py | 48 +- safir/src/safir/uws/_app.py | 134 +--- safir/src/safir/uws/_config.py | 77 +- safir/src/safir/uws/_constants.py | 16 +- safir/src/safir/uws/_dependencies.py | 64 +- safir/src/safir/uws/_exceptions.py | 71 +- safir/src/safir/uws/_handlers.py | 100 +-- safir/src/safir/uws/_models.py | 717 ++++++++++++------ safir/src/safir/uws/_responses.py | 8 +- safir/src/safir/uws/_results.py | 10 +- safir/src/safir/uws/_schema.py | 101 --- safir/src/safir/uws/_service.py | 381 +++++----- safir/src/safir/uws/_storage.py | 536 +++++++------ safir/src/safir/uws/_workers.py | 90 +-- safir/src/safir/uws/templates/error.xml | 2 +- safir/tests/data/database/uws/README.md | 9 - safir/tests/data/database/uws/alembic.ini | 17 - safir/tests/data/database/uws/alembic/env.py | 25 - .../data/database/uws/alembic/script.py.mako | 26 - .../20240911_0000_e9299566bc19_uws_schema.py | 127 ---- safir/tests/support/uws.py | 23 +- safir/tests/uws/alembic_test.py | 65 -- safir/tests/uws/conftest.py | 77 +- safir/tests/uws/errors_test.py | 66 +- safir/tests/uws/job_api_test.py | 249 ++---- safir/tests/uws/job_error_test.py | 78 +- safir/tests/uws/job_list_test.py | 96 +-- safir/tests/uws/long_polling_test.py | 50 +- safir/tests/uws/post_params_test.py | 30 - safir/tests/uws/workers_test.py | 129 ++-- 44 files changed, 1852 insertions(+), 2282 deletions(-) create mode 100644 changelog.d/20241209_145305_rra_DM_47986.md delete mode 100644 
safir/src/safir/uws/_schema.py delete mode 100644 safir/tests/data/database/uws/README.md delete mode 100644 safir/tests/data/database/uws/alembic.ini delete mode 100644 safir/tests/data/database/uws/alembic/env.py delete mode 100644 safir/tests/data/database/uws/alembic/script.py.mako delete mode 100644 safir/tests/data/database/uws/alembic/versions/20240911_0000_e9299566bc19_uws_schema.py delete mode 100644 safir/tests/uws/alembic_test.py delete mode 100644 safir/tests/uws/post_params_test.py diff --git a/changelog.d/20241209_145305_rra_DM_47986.md b/changelog.d/20241209_145305_rra_DM_47986.md new file mode 100644 index 00000000..88c93840 --- /dev/null +++ b/changelog.d/20241209_145305_rra_DM_47986.md @@ -0,0 +1,7 @@ +### Backwards-incompatible changes + +- Rewrite the Safir UWS support to use Pydantic models for job parameters. Services built on the Safir UWS library will need to change all job creation dependencies to return Pydantic models. +- Use the Wobbly service rather than a direct database connection to store UWS job information. Services built on the Safir UWS library must now configure a Wobbly URL and will switch to Wobbly's storage instead of their own when updated to this release of Safir. +- Support an execution duration of 0 in the Safir UWS library, mapping it to no limit on the execution duration. Note that this will not be allowed by the default configuration and must be explicitly allowed by an execution duration validation hook. +- Convert all models returned by the Safir UWS library to Pydantic. Services built on the Safir UWS library will have to change the types of validator functions for destruction time and execution duration. +- Safir no longer provides the `safir.uws.ErrorCode` enum or the exception `safir.uws.MultiValuedParameterError`. These values were specific to a SODA service, and different IVOA UWS services use different error codes. 
The Safir UWS library now takes the error code as a string, and each application should define its own set of error codes in accordance with the IVOA standard it is implementing. diff --git a/docs/_rst_epilog.rst b/docs/_rst_epilog.rst index 8efda2a6..4f676d48 100644 --- a/docs/_rst_epilog.rst +++ b/docs/_rst_epilog.rst @@ -16,6 +16,7 @@ .. _pre-commit: https://pre-commit.com .. _Pydantic: https://docs.pydantic.dev/latest/ .. _Pydantic BaseSettings: https://docs.pydantic.dev/latest/concepts/pydantic_settings +.. _pydantic-xml: https://pydantic-xml.readthedocs.io/en/latest/ .. _PyPI: https://pypi.org/project/safir/ .. _pytest: https://docs.pytest.org/en/latest/ .. _redis-py: https://redis.readthedocs.io/en/stable/ @@ -33,3 +34,4 @@ .. _Uvicorn: https://www.uvicorn.org/ .. _virtualenvwrapper: https://virtualenvwrapper.readthedocs.io/en/stable/ .. _vo-models: https://vo-models.readthedocs.io/latest/ +.. _Wobbly: https://github.com/lsst-sqre/wobbly/ diff --git a/docs/conf.py b/docs/conf.py index af24870e..74b31c3a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,3 +1,7 @@ from documenteer.conf.guide import * +# Disable JSON schema because it doesn't seem that useful and apparently can't +# deal with generics, so it produces warnings for the UWS Job model. +autodoc_pydantic_model_show_json = False + html_sidebars["api"] = [] # no sidebar on the API page diff --git a/docs/user-guide/database/schema.rst b/docs/user-guide/database/schema.rst index 80dab1b8..9747f449 100644 --- a/docs/user-guide/database/schema.rst +++ b/docs/user-guide/database/schema.rst @@ -7,7 +7,6 @@ Safir provides some additional supporting functions to make using Alembic more s These instructions assume that you have already defined your schema with SQLAlchemy's ORM model. If you have not already done that, do that first. -For UWS applications that only have the UWS database, the declarative base of the schema is `safir.uws.UWSSchemaBase`. 
Set up Alembic ============== @@ -85,7 +84,6 @@ Replace :file:`alembic/env.py` with the following: ) Replace ``example`` with the module name and application name of your application as appropriate. -For applications that only use the UWS database, replace ``example.schema.Base`` in the above with `safir.uws.UWSSchemaBase`. Add Alembic to the Docker image ------------------------------- diff --git a/docs/user-guide/uws/create-a-service.rst b/docs/user-guide/uws/create-a-service.rst index aac3f3b7..5ea3177a 100644 --- a/docs/user-guide/uws/create-a-service.rst +++ b/docs/user-guide/uws/create-a-service.rst @@ -10,8 +10,8 @@ Select the ``UWS`` flavor. Then, flesh out the application by following these steps: -#. :doc:`Define the API parameters ` #. :doc:`Define the parameter models ` +#. :doc:`Define the API parameters ` #. :doc:`Write the backend worker ` #. :doc:`Write the test suite ` @@ -40,7 +40,10 @@ This will add standard configuration options most services will need and provide Second, add a property to ``Config`` that returns the UWS configuration. For some of these settings, you won't know the values yet. -You will be able to fill in the value of ``parameters_type`` after reading :doc:`define-models`, the values of ``async_post_route`` and optionally ``sync_get_route`` and ``sync_post_route`` after reading :doc:`define-inputs`, and the value of ``worker`` after reading :doc:`write-backend`. + +You will be able to fill in the values of ``job_summary_type`` and ``parameters_type`` after reading :doc:`define-models`. +You will be able to fill in the values of ``async_post_route`` and optionally ``sync_get_route`` and ``sync_post_route`` after reading :doc:`define-inputs`. +You will be able to fill in the value of ``worker`` after reading :doc:`write-backend`. For now, you can just insert placeholder values. .. 
code-block:: python @@ -88,11 +91,12 @@ Set up the FastAPI application The Safir UWS library must be initialized when the application starts, and requires some additional FastAPI middleware and error handlers. These need to be added to :file:`main.py`. -First, initialize the UWS application in the ``lifespan`` function: +First, initialize and shut down the UWS application in the ``lifespan`` function: .. code-block:: python :caption: main.py - :emphasize-lines: 1,6,8 + + from safir.dependencies.http_client import http_client_dependency from .config import uws @@ -104,7 +108,7 @@ First, initialize the UWS application in the ``lifespan`` function: await uws.shutdown_fastapi() await http_client_dependency.aclose() -Second, install the UWS routes into the external router before including it in the application: +Second, install the UWS routes into the external router **before** including it in the application: .. code-block:: python :caption: main.py @@ -128,94 +132,6 @@ Third, install the UWS middleware and error handlers. # Install error handlers. uws.install_error_handlers(app) -Add a command-line interface -============================ - -The UWS implementation uses a PostgreSQL database to store job status. -Your application will need a mechanism to initialize that database with the desired schema. -The simplest way to do this is to add a command-line interface for your application with an ``init`` command that initializes the database. - -.. note:: - - This approach has inherent race conditions and cannot handle database schema upgrades. - It will be replaced with a more sophisticated approach using Alembic_ once that support is ready. - -First, create a new :file:`cli.py` file in your application with the following contents: - -.. 
code-block:: python - :caption: cli.py - - import click - import structlog - from safir.asyncio import run_with_asyncio - from safir.click import display_help - - from .config import uws - - - @click.group(context_settings={"help_option_names": ["-h", "--help"]}) - @click.version_option(message="%(version)s") - def main() -> None: - """Administrative command-line interface for example.""" - - - @main.command() - @click.argument("topic", default=None, required=False, nargs=1) - @click.pass_context - def help(ctx: click.Context, topic: str | None) -> None: - """Show help for any command.""" - display_help(main, ctx, topic) - - - @main.command() - @click.option( - "--reset", is_flag=True, help="Delete all existing database data." - ) - @run_with_asyncio - async def init(*, reset: bool) -> None: - """Initialize the database storage.""" - logger = structlog.get_logger("example") - await uws.initialize_uws_database(logger, reset=reset) - -Look for the instances of ``example`` and replace them with the name of your application. - -Second, register this interface with Python in :file:`pyproject.toml`: - -.. code-block:: toml - :caption: pyproject.toml - - [project.scripts] - example = "example.cli:main" - -Again, replace ``example`` with the name of your application. - -Third, change the :file:`Dockerfile` for your application to run a startup script rather than run :command:`uvicorn` directly: - -.. code-block:: docker - :caption: Dockerfile - - # Copy the startup script - COPY scripts/start-frontend.sh /start-frontend.sh - - # Run the application. - CMD ["/start-frontend.sh"] - -Finally, create the :file:`scripts/start-frontend.sh` file: - -.. code-block:: bash - :caption: scripts/start-frontend.sh - - #!/bin/bash - # - # Create the database and then start the server. - - set -eu - - example init - uvicorn example.main:app --host 0.0.0.0 --port 8080 - -Again, replace ``example`` with the name of your application. 
- Create the arq worker for database updates ========================================== @@ -248,6 +164,6 @@ Next steps Now that you have set up the basic structure of your application, you can move on to the substantive parts. -- Define the API parameters: :doc:`define-inputs` - Define the parameter models: :doc:`define-models` +- Define the API parameters: :doc:`define-inputs` - Write the backend worker :doc:`write-backend` diff --git a/docs/user-guide/uws/define-inputs.rst b/docs/user-guide/uws/define-inputs.rst index 4a543a88..71f6e19a 100644 --- a/docs/user-guide/uws/define-inputs.rst +++ b/docs/user-guide/uws/define-inputs.rst @@ -6,25 +6,8 @@ Defining service inputs Your UWS service will take one more more input parameters. The UWS library cannot know what those parameters are, so you will need to define them and pass that configuration into the UWS library configuration. -This is done by writing a FastAPI dependency that returns a list of input parameters as key/value pairs. - -What parameters look like -========================= - -UWS input parameters for a job are a list of key/value pairs. -The value is always a string. -Other data types are not directly supported. -If your service needs a different data type as a parameter value, you will need to accept it as a string and then parse it into a more complex structure. -See :doc:`define-models` for how to do that. - -All FastAPI dependencies provided by your application must return a list of `UWSJobParameter` objects. -The ``parameter_id`` attribute is the key and the ``value`` attribute is the value. - -The key (the ``parameter_id``) is case-insensitive in the input, but it will be lowercased by middleware installed by Safir. -You will therefore always see lowercase query and form parameters in your dependency and do not have to handle other case possibilities. - -UWS allows the same ``parameter_id`` to occur multiple times with different values. 
-For example, multiple ``id`` parameters may specify multiple input objects for a bulk operation that processes all of the input objects at the same time. +This is done by writing a FastAPI dependency that returns a Pydantic model for your job parameters. +See :doc:`define-models` for details on how to define that model. Ways to create jobs =================== @@ -43,12 +26,12 @@ Sync jobs are not supported by default, but can be easily enabled. Sync jobs can be created via either ``POST`` or ``GET``. You can pick whether your application will support sync ``POST``, sync ``GET``, both, or neither. Supporting ``GET`` makes it easier for people to assemble ad hoc jobs by writing the URL directly in their web browser. -However, due to unfixable web security reasons, ``GET`` jobs can be created by any malicious site on the Internet, and therefore should not be supported if the operation of your service is destructive, expensive, or dangerous if performed by unauthorized people. +However, due to unfixable web security limitations in the HTTP protocol, ``GET`` jobs can be created by any malicious site on the Internet, and therefore should not be supported if the operation of your service is destructive, expensive, or dangerous if performed by unauthorized people. -For each supported way to create a job, your application must provide a FastAPI dependency that reads input parameters via that method and returns a list of `UWSJobParameter` objects. +For each supported way to create a job, your application must provide a FastAPI dependency that reads input parameters via that method and returns the Pydantic model for parameters that you defined in :doc:`define-models`. Async POST dependency ---------------------- +===================== Supporting async ``POST`` is required. First, writing a FastAPI dependency that accepts the input parameters for your job as `form parameters `__. 
@@ -88,22 +71,19 @@ Here is an example for a SODA service that performs circular cutouts: ), ), ] = None, - ) -> list[UWSJobParameter]: - """Parse POST parameters into job parameters for a cutout.""" - params = [] - for i in id: - params.append(UWSJobParameter(paramater_id="id", value=i)) - for c in circle: - params.append(UWSJobParameter(parameter_id="circle", value=c)) - return params + ) -> CutoutParameters: + return CutoutParameters( + ids=id, + stencils=[CircleStencil.from_string(c) for c in circle], + ) This first declares the input parameters, with full documentation, as FastAPI ``Form`` parameters. - Note that the type is ``list[str]``, which allows the parameter to be specified multiple times. If the parameters for your service cannot be repeated, change this to `str` (or another appropriate basic type, such as `int`). -You do not need to do any input validation of the parameter values here. -This will be done later as part of converting the input parameters to your parameter model, as defined in :doc:`define-models`. +Then, it converts the form parameters into the Pydantic model for your job parameters. +Here, most of the work is done by the ``from_string`` class method on ``CircleStencil``, defined in :ref:`uws-model-parameters`. +This conversion should also perform any necessary input validation. Async POST configuration ------------------------ @@ -134,9 +114,9 @@ The ``summary`` and ``description`` attributes are only used to generate the API They contain a brief summary and a longer description of the async ``POST`` route and will be copied into the generated OpenAPI specification for the service. Sync POST ---------- +========= -Supporting sync ``POST`` is very similar: define a FastAPI dependency that accepts ``POST`` parameters and returns a list of `UWSJobParameter` objects, and then define a `UWSRoute` object including that dependency and pass it as the ``sync_post_route`` argument to `UWSAppSettings.build_uws_config`. 
+Supporting sync ``POST`` is very similar: define a FastAPI dependency that accepts ``POST`` parameters and returns your Pydantic parameter model, and then define a `UWSRoute` object including that dependency and pass it as the ``sync_post_route`` argument to `UWSAppSettings.build_uws_config`. By default, sync ``POST`` is not supported. Normally, the input parameters for sync ``POST`` will be the same as the input parameters for async ``POST``, so you can reuse the same FastAPI dependency. @@ -165,7 +145,7 @@ Here is an example for the same cutout service: This would then be passed as the ``sync_post_route`` argument. Sync GET --------- +======== Supporting sync ``GET`` follows the same pattern, but here you will need to define a separate dependency that takes query parameters rather than form parameters. Here is an example dependency for a cutout service: @@ -202,14 +182,14 @@ Here is an example dependency for a cutout service: ), ), ], - request: Request, - ) -> list[UWSJobParameter]: - """Parse GET parameters into job parameters for a cutout.""" - return [ - UWSJobParameter(parameter_id=k, value=v) - for k, v in request.query_params.items() - if k in {"id", "circle"} - ] + ) -> CutoutParameters: + return CutoutParameters( + ids=id, + stencils=[CircleStencil.from_string(c) for c in circle], + ) + +The body here is identical to the body of the dependency for ``POST``. +The difference is in how the parameters are defined (``Query`` vs. ``Form``). As in the other cases, you will then need to pass a `UWSRoute` object as the ``sync_get_route`` argument to `UWSAppSettings.build_uws_config`. Here is an example: @@ -238,5 +218,4 @@ This would then be passed as the ``sync_post_route`` argument. 
Next steps ========== -- Define the parameter models: :doc:`define-models` - Write the backend worker :doc:`write-backend` diff --git a/docs/user-guide/uws/define-models.rst b/docs/user-guide/uws/define-models.rst index 074fbf28..8ca87146 100644 --- a/docs/user-guide/uws/define-models.rst +++ b/docs/user-guide/uws/define-models.rst @@ -25,6 +25,19 @@ Unfortunately, the same model cannot be used for 1 and 3 even for simple applica Therefore, in the most general case, UWS applications must define three models for input parameters: the API model of parameters as provided by users via a JSON API, the model passed to the backend worker, and an XML model that flattens all parameters to strings. The input parameters for job creation via ``POST`` and ``GET`` are discussed in :doc:`define-inputs`. +What parameters look like +========================= + +In the IVOA UWS standard, input parameters for a job are a list of key/value pairs. +The value is always a string. +Other data types are not directly supported. + +In the Safir UWS support, however, job parameters are allowed to be arbitrary Pydantic models. +The only requirement is that it must be possible to serialize the parameters to a list of key/value pairs so that they can be returned by IVOA UWS standard routes. +In other words, the internal representation can be as complex as you wish, but the IVOA UWS standard requires the input parameters come from query or form parameters and be representable as a list of key/value pairs. + +Therefore, if your service needs a different data type as a parameter value, you will need to accept it as a string and then parse it into a more complex structure, and you will need to be able to convert your Pydantic model back to that list of strings. + .. _uws-worker-model: Worker parameter model @@ -99,13 +112,15 @@ Input validation will be done by the input parameter model. 
Single-valued parameters can use the syntax shown in `the vo-models documentation `__ to define the parameter ID if it differs from the attribute name. Optional multi-valued parameters, such as the above, have to use attribute names that match the XML parameter ID and the ``Field([])`` syntax to define the default to be an empty list, or you will get typing errors. +.. _uws-model-parameters: + Input parameter model ===================== Every UWS application must define a Pydantic model for its input parameters. This model must inherit from `ParametersModel`. -In addition to defining the parameter model, it must provide three methods: a class method named ``from_job_parameters`` that takes as input the list of `UWSJobParameter` objects and returns an instance of the model, an instance method named ``to_worker_parameters`` that converts the model to the one that will be passed to the backend worker (see :ref:`uws-worker-model`), and an instance method named ``to_xml_model`` that converts the model to the XML model (see :ref:`uws-xml-model`). +In addition to defining the parameter model, it must provide two methods: an instance method named ``to_worker_parameters`` that converts the model to the one that will be passed to the backend worker (see :ref:`uws-worker-model`), and an instance method named ``to_xml_model`` that converts the model to the XML model (see :ref:`uws-xml-model`). Often, the worker parameter model will look very similar to the input parameter model. They are still kept separate, since the input parameter model defines the API and the worker model defines the interface to the backend. 
@@ -139,21 +154,6 @@ Here is an example of a simple model for a cutout service: ids: list[str] = Field(..., title="Dataset IDs") stencils: list[CircleStencil] = Field(..., title="Cutout stencils") - @classmethod - def from_job_parameters(cls, params: list[UWSJobParameter]) -> Self: - ids = [] - stencils = [] - try: - for param in params: - if param.parameter_id == "id": - ids.append(param.value) - else: - stencils.append(CircleStencil.from_string(param.value)) - except Exception as e: - msg = f"Invalid cutout parameter: {type(e).__name__}: {e!s}" - raise ParameterParseError(msg, params) from e - return cls(ids=ids, stencils=stencils) - def to_worker_parameters(self) -> WorkerCutout: return WorkerCutout(dataset_ids=self.ids, stencils=self.stencils) @@ -168,10 +168,15 @@ Notice that the input parameter model reuses some models from the worker (``Poin It also uses a different parameter for the dataset IDs (``ids`` instead of ``dataset_ids``), which is a trivial example of the sort of divergence one might see between input models and backend worker models. ``CutoutXmlParameters`` is defined in :ref:`uws-xml-model`. -The input models are also responsible for input parsing and validation (note the ``from_job_parameters`` and ``from_string`` methods) and converting to the worker model. +The ``from_string`` class method of ``CircleStencil`` is not used here directly. +This will be used when parsing query or form inputs into a Pydantic model. +See :doc:`define-inputs` for more details. + +The input models are also responsible for input parsing and validation, and converting to the worker and XML models. The worker model should be in a separate file and kept as simple as possible, since it has to be imported by the backend worker, which may not have the dependencies installed to be able to import other frontend code. 
-The XML model must use simple key/value pairs of strings to satisfy the UWS XML API, so ``to_xml_model`` may need to do some conversion from the model back to a string representation of the parameters. +The XML model must use simple key/value pairs of strings to satisfy the UWS XML API, so ``to_xml_model`` may need to do some conversion from the model to a string representation of the parameters. +This string representation should match the input accepted by the dependencies defined in :doc:`define-inputs`. Update the application configuration ==================================== @@ -183,8 +188,10 @@ In the example above, that would be ``CutoutParameters``. Set the ``job_summary_type`` argument to ``JobSummary[XmlModel]`` where ``XmlModel`` is whatever the class name of your XML parameter model is. In the example above, that would be ``JobSummary[CutoutXmlParameters]``. +(Although this type is theoretically knowable through type propagation, limitations in the pydantic-xml_ library require specifying it separately.) Next steps ========== +- Define the API parameters: :doc:`define-inputs` - Write the backend worker :doc:`write-backend` diff --git a/docs/user-guide/uws/index.rst b/docs/user-guide/uws/index.rst index 15137e60..e2b11717 100644 --- a/docs/user-guide/uws/index.rst +++ b/docs/user-guide/uws/index.rst @@ -12,6 +12,8 @@ Applications built with this framework have three components: #. A backend arq_ worker, possibly running on a different software stack, that does the work. #. A database arq_ worker that handles bookkeeping and result processing for the backend worker. +In addition, they use the Wobbly_ service to do the work of recording and retrieving job information so that each service doesn't have to manage its own underlying PostgreSQL database. 
+ Incoming requests are turned into arq_ jobs, processed by the backend worker, uploaded to Google Cloud Storage, recorded in a database, and then returned to the user via a frontend that reads the result URLs and other metadata from the database. Frontend applications that use this library must depend on ``safir[uws]``. @@ -24,7 +26,7 @@ Guides :titlesonly: create-a-service - define-inputs define-models + define-inputs write-backend testing diff --git a/docs/user-guide/uws/testing.rst b/docs/user-guide/uws/testing.rst index 5fb7e67d..0558fcdd 100644 --- a/docs/user-guide/uws/testing.rst +++ b/docs/user-guide/uws/testing.rst @@ -13,64 +13,134 @@ Frontend testing fixtures ========================= The frontend of a UWS application assumes that arq_ will execute both jobs and the database worker that recovers the results of a job and stores them in the database. -During testing of the frontend, arq will not be running, and therefore this execution must be simulated. -This is done with the `MockUWSJobRunner` class, but it requires some setup. +It also assumes Wobbly will be available as an API for storing and retrieving jobs in the database. +During testing of the frontend, arq and Wobbly will not be running, and therefore this execution must be simulated. +This is done with the `MockWobbly` and `MockUWSJobRunner` classes, but it requires some setup. -Mock the arq queue ------------------- +Mock Wobbly +----------- -First, the application must be configured to use a `~safir.arq.MockArqQueue` class instead of one based on Redis. -This stores all queued jobs in memory and provides some test-only methods to manipulate them. +Add a development dependency on respx_ to the application. +Then, add a Wobbly mock fixture to :file:`tests/conftest.py`: -To do this, first set up a fixture in :file:`tests/conftest.py` that provides a mock arq queue: +.. 
code-block:: python + :caption: tests/conftest.py + + import pytest + import respx + from safir.testing.uws import MockWobbly, patch_wobbly + + from example.config import config + + + @pytest.fixture + def mock_wobbly(respx_mock: respx.Router) -> MockWobbly: + return patch_wobbly(respx_mock, str(config.wobbly_url)) + +Change ``example.config`` to the config module for your application. + +You will need to arrange for ``wobbly_url`` to be set in your application configuration during testing. +Normally the easiest way to do that is to set :samp:`{APPLICATION}_WOBBLY_URL` to some reasonable placeholder value such as ``http://example.com/wobbly`` in your application's tox configuration for running tests. + +Then, arrange for this fixture to be enabled by the test client for your application. +Usually the easiest way to do that is to make the ``mock_wobbly`` fixture a parameter to the ``app`` fixture that sets up the application for testing. + +Mock the request token +---------------------- + +All of the UWS routes provided by Safir expect to receive username and token information in the incoming request. +Normally these headers are generated by Gafaelfawr based on the authentication credentials of the request. +Inside the test suite, you will need to provide the headers that Gafaelfawr would have provided when making requests to the UWS routes. + +To do this, first create fixtures that define a test username and test service: .. code-block:: python :caption: tests/conftest.py import pytest - from safir.arq import MockArqQueue @pytest.fixture - def arq_queue() -> MockArqQueue: - return MockArqQueue() + def test_service() -> str: + return "test-service" -Then, configure the application to use that arq queue instead of the default one in the ``app`` fixture. + + @pytest.fixture + def test_username() -> str: + return "test-user" + +Then, create a fixture that returns a token for the Wobbly mock that encodes that username and service. 
+This allows the mock to recover the intended username and service of a request from the Safir UWS code to Wobbly. .. code-block:: python :caption: tests/conftest.py - :emphasize-lines: 8,14 - from collections.abc import AsyncIterator + import pytest + from safir.testing.uws import MockWobbly - from asgi_lifespan import LifespanManager - from fastapi import FastAPI - from safir.arq import MockArqQueue - from example import main - from example.config import uws + @pytest.fixture + def test_token(test_service: str, test_username: str) -> str: + return MockWobbly.make_token(test_service, test_username) + +Finally, configure the ``client`` fixture to send the appropriate headers by default. + +.. code-block:: python + :caption: tests/conftest.py + :emphasize-lines: 8,13-16 + + import pytest + from fastapi import FastAPI + from httpx import ASGITransport, AsyncClient @pytest_asyncio.fixture - async def app(arq_queue: MockArqQueue) -> AsyncIterator[FastAPI]: - async with LifespanManager(main.app): - uws.override_arq_queue(arq_queue) - yield main.app + async def client( + app: FastAPI, test_token: str, test_username: str + ) -> AsyncIterator[AsyncClient]: + async with AsyncClient( + transport=ASGITransport(app=app), + base_url="https://example.com/", + headers={ + "X-Auth-Request-Token": test_token, + "X-Auth-Request-User": test_username, + }, + ) as client: + yield client + +Any tests that need to know the value of the username or service on whose behalf operations will be performed can use the ``test_username`` and ``test_service`` fixtures. +The token will be required for calls to `MockUWSJobRunner` (see below) and can be accessed via the ``test_token`` fixture. + +If a specific test needs to send a request from a different service or username (to test handling of multiple usernames, for instance), it should override the request headers to send a different token and username in ``X-Auth-Request-Token`` and ``X-Auth-Request-User``. 
+To generate a token for a different service and username pair, call `MockWobbly.make_token`. + +Mock the arq queue +------------------ + +First, the application must be configured to use a `~safir.arq.MockArqQueue` class instead of one based on Redis. +This stores all queued jobs in memory and provides some test-only methods to manipulate them. -Provide a test database ------------------------ +To do this, first set up a fixture in :file:`tests/conftest.py` that provides a mock arq queue: -UWS relies on database in which to store job information and results. -Follow the instructions in :doc:`/user-guide/database/testing` to use tox-docker_ to create a test PostgreSQL database, but skip the instructions there for initializing the database. -Instead, use the UWS library to initialize the resulting database: +.. code-block:: python + :caption: tests/conftest.py + + import pytest + from safir.arq import MockArqQueue + + + @pytest.fixture + def arq_queue() -> MockArqQueue: + return MockArqQueue() + +Then, configure the application to use that arq queue instead of the default one in the ``app`` fixture. .. code-block:: python :caption: tests/conftest.py - :emphasize-lines: 3,14-15 + :emphasize-lines: 8,14 from collections.abc import AsyncIterator - import structlog from asgi_lifespan import LifespanManager from fastapi import FastAPI from safir.arq import MockArqQueue @@ -81,8 +151,6 @@ Instead, use the UWS library to initialize the resulting database: @pytest_asyncio.fixture async def app(arq_queue: MockArqQueue) -> AsyncIterator[FastAPI]: - logger = structlog.get_logger("example") - await uws.initialize_uws_database(logger, reset=True) async with LifespanManager(main.app): uws.override_arq_queue(arq_queue) yield main.app @@ -120,8 +188,6 @@ This will simulate not only the execution of a backend worker that results some .. 
code-block:: python :caption: tests/conftest.py - from collections.abc import AsyncIterator - import pytest_asyncio from safir.arq import MockArqQueue from safir.testing.uws import MockUWSJobRunner @@ -130,11 +196,8 @@ This will simulate not only the execution of a backend worker that results some @pytest_asyncio.fixture - async def runner( - arq_queue: MockArqQueue, - ) -> AsyncIterator[MockUWSJobRunner]: - async with MockUWSJobRunner(config.uws_config, arq_queue) as runner: - yield runner + async def runner(arq_queue: MockArqQueue) -> MockUWSJobRunner: + return MockUWSJobRunner(config.uws_config, arq_queue) Writing a frontend test ======================= @@ -149,37 +212,35 @@ Here is an example of a test of a hypothetical cutout service. import pytest from httpx import AsyncClient from safir.testing.uws import MockUWSJobRunner + from safir.uws import JobResult @pytest.mark.asyncio async def test_create_job( - client: AsyncClient, runner: MockUWSJobRunner + client: AsyncClient, test_token: str, runner: MockUWSJobRunner ) -> None: r = await client.post( "/api/cutout/jobs", - headers={"X-Auth-Request-User": "someone"}, data={"ID": "1:2:band:value", "Pos": "CIRCLE 0 1 2"}, ) assert r.status_code == 303 assert r.headers["Location"] == "https://example.com/api/cutout/jobs/1" - await runner.mark_in_progress("someone", "1") + await runner.mark_in_progress(test_token, "1") async def run_job() -> None: results = [ - UWSJobResult( - result_id="cutout", + JobResult( + id="cutout", url="s3://some-bucket/some/path", mime_type="application/fits", ) ] - await runner.mark_complete("someone", "1", results, delay=0.2) + await runner.mark_complete(test_token, "1", results, delay=0.2) _, r = await asyncio.gather( run_job(), client.get( - "/api/cutout/jobs/1", - headers={"X-Auth-Request-User": "someone"}, - params={"wait": 2, "phase": "EXECUTING"}, + "/api/cutout/jobs/1", params={"wait": 2, "phase": "EXECUTING"} ), ) assert r.status_code == 200 diff --git 
a/docs/user-guide/uws/write-backend.rst b/docs/user-guide/uws/write-backend.rst index 3a16e0e7..523aa418 100644 --- a/docs/user-guide/uws/write-backend.rst +++ b/docs/user-guide/uws/write-backend.rst @@ -141,7 +141,7 @@ See its documentation for a full list. There are two attributes that deserve special mention, however. The ``token`` attribute contains a delegated Gafaelfawr_ token to act on behalf of the user. -This token must be included in an :samp:`Authorization: bearer ` header in any web request that the backend makes to other Rubin Science Platform services. +This token must be included in an :samp:`Authorization: bearer {token}` header in any web request that the backend makes to other Rubin Science Platform services. The ``timeout`` attribute contains a `~datetime.timedelta` representation of the timeout for the job. The backend ideally should arrange to not exceed that total wall clock interval when executing. diff --git a/noxfile.py b/noxfile.py index 48a6a9ef..1fef82d0 100644 --- a/noxfile.py +++ b/noxfile.py @@ -20,11 +20,14 @@ nox.options.default_venv_backend = "uv" nox.options.reuse_existing_virtualenvs = True -# pip-installable dependencies for all the Safir modules. +# pip-installable dependencies for all the Safir modules. The local safir-arq +# apparently has to be installed after safir itself or the safir dependency +# resolution appears to replace the local install with the safir-arq package +# from PyPI. 
PIP_DEPENDENCIES = ( + ("-e", "./safir[arq,db,dev,gcs,kubernetes,redis,uws]"), ("-e", "./safir-logging"), ("-e", "./safir-arq"), - ("-e", "./safir[arq,db,dev,gcs,kubernetes,redis,uws]"), ) diff --git a/safir-arq/src/safir/arq/uws.py b/safir-arq/src/safir/arq/uws.py index daf59443..8fce0669 100644 --- a/safir-arq/src/safir/arq/uws.py +++ b/safir-arq/src/safir/arq/uws.py @@ -343,7 +343,7 @@ async def run( logger = logger.bind(run_id=info.run_id) start = datetime.now(tz=UTC) - await arq.enqueue("uws_job_started", info.job_id, start) + await arq.enqueue("uws_job_started", info.token, info.job_id, start) loop = asyncio.get_running_loop() try: async with asyncio.timeout(info.timeout.total_seconds()): @@ -358,7 +358,7 @@ async def run( ctx["pool"] = _restart_pool(pool) raise WorkerTimeoutError(elapsed, info.timeout) from None finally: - await arq.enqueue("uws_job_completed", info.job_id) + await arq.enqueue("uws_job_completed", info.token, info.job_id) # Since the worker is running sync jobs, run one job per pod since they # will be serialized anyway and no parallelism is possible. 
This also diff --git a/safir/pyproject.toml b/safir/pyproject.toml index a45d6eac..4cc80242 100644 --- a/safir/pyproject.toml +++ b/safir/pyproject.toml @@ -80,14 +80,11 @@ redis = [ "redis>4.5.2,<6", ] uws = [ - "alembic[tz]<2", - "asyncpg<1", "google-auth<3", "google-cloud-storage<3", "jinja2<4", "python-multipart", "safir-arq<7", - "sqlalchemy[asyncio]>=2.0.0,<3", "vo-models>=0.4.1,<1", ] diff --git a/safir/src/safir/testing/uws.py b/safir/src/safir/testing/uws.py index e66639b5..118e0f9a 100644 --- a/safir/src/safir/testing/uws.py +++ b/safir/src/safir/testing/uws.py @@ -1,23 +1,43 @@ -"""Mock UWS job executor for testing.""" +"""Mocks and functions for testing services using the Safir UWS support.""" from __future__ import annotations import asyncio +import json +from collections import defaultdict from datetime import UTC, datetime -from types import TracebackType -from typing import Literal, Self +from urllib.parse import parse_qs +import respx import structlog -from sqlalchemy.ext.asyncio import AsyncEngine +from httpx import AsyncClient, Request, Response from vo_models.uws import JobSummary +from vo_models.uws.types import ExecutionPhase from safir.arq import JobMetadata, JobResult, MockArqQueue -from safir.database import create_async_session, create_database_engine -from safir.uws import UWSConfig, UWSJob, UWSJobResult +from safir.arq.uws import WorkerResult +from safir.datetime import current_datetime, parse_isodatetime +from safir.uws import ( + Job, + JobCreate, + JobUpdateAborted, + JobUpdateCompleted, + JobUpdateError, + JobUpdateExecuting, + JobUpdateMetadata, + JobUpdateQueued, + SerializedJob, + UWSConfig, +) from safir.uws._service import JobService from safir.uws._storage import JobStore -__all__ = ["MockUWSJobRunner", "assert_job_summary_equal"] +__all__ = [ + "MockUWSJobRunner", + "MockWobbly", + "assert_job_summary_equal", + "patch_wobbly", +] def assert_job_summary_equal( @@ -44,6 +64,199 @@ def assert_job_summary_equal( assert 
seen_model.model_dump() == expected_model.model_dump() +class MockWobbly: + """Mock the Wobbly web service, which stores UWS job information. + + Use of this mock web service requires presentation of a token generated by + the `make_token` class method, which encodes username and service + information that would normally be taken from HTTP headers after + Gafaelfawr processing. + + Attributes + ---------- + jobs + Stored jobs, organized by service, username, and then job ID. + """ + + @staticmethod + def make_token(service: str, username: str) -> str: + """Create a fake internal token for Wobbly calls. + + Parameters + ---------- + service + Service name encoded in fake internal token. + username + Username encoded in fake internal token. + + Returns + ------- + str + Fake internal token. + """ + return f"gt-{service}.{username}" + + def __init__(self) -> None: + # Next available job ID. + self._job_id = 1 + + # Maps service to username to job ID to a job record. + self.jobs: defaultdict[str, defaultdict[str, dict[str, SerializedJob]]] + self.jobs = defaultdict(lambda: defaultdict(dict)) + + def create_job(self, request: Request) -> Response: + """Create a new job record.""" + service, username = self._get_auth(request) + job_create = JobCreate.model_validate_json(request.content) + job_id = str(self._job_id) + self._job_id += 1 + job = SerializedJob( + id=job_id, + service=service, + owner=username, + phase=ExecutionPhase.PENDING, + creation_time=current_datetime(), + **job_create.model_dump(), + ) + self.jobs[service][username][job_id] = job + return Response( + 201, + json=job.model_dump(mode="json"), + headers={"Location": str(request.url) + f"/{job_id}"}, + ) + + def delete_job(self, request: Request, *, job_id: str) -> Response: + """Delete a job.""" + service, username = self._get_auth(request) + if job_id not in self.jobs[service][username]: + return Response(404) + del self.jobs[service][username][job_id] + return Response(204) + + def get_job(self, 
request: Request, *, job_id: str) -> Response: + """Retrieve a job.""" + service, username = self._get_auth(request) + job = self.jobs[service][username].get(job_id) + if not job: + return Response(404) + return Response(200, json=job.model_dump(mode="json")) + + def list_jobs(self, request: Request) -> Response: + """List jobs matching the search parameters. + + Cursors are not supported. ``limit`` is, but does not result in a + paginated response. + """ + service, username = self._get_auth(request) + + # Parse query. + query = parse_qs(request.url.query.decode()) + phases = set() + if "phase" in query: + phases = set(query["phase"]) + since = None + if "since" in query: + since = parse_isodatetime(query["since"][0]) + + # Perform the search. + results = [] + for job in self.jobs[service][username].values(): + if phases and job.phase not in phases: + continue + if since and job.creation_time <= since: + continue + results.append(job) + + # Sort the results and limit them if needed. + json_results = [ + j.model_dump(mode="json") + for j in sorted( + results, key=lambda j: (j.creation_time, j.id), reverse=True + ) + ] + if "limit" in query: + limit = int(query["limit"][0]) + if len(json_results) > limit: + json_results = json_results[:limit] + + # Return the response. + return Response(200, json=json_results) + + def update_job(self, request: Request, *, job_id: str) -> Response: + """Make an update to a job record.""" + service, username = self._get_auth(request) + job = self.jobs[service][username].get(job_id) + if not job: + return Response(404) + body = json.loads(request.content) + + # First handle the only case without a phase, which is a metadata + # update. 
+ if "phase" not in body or not body["phase"]: + update = JobUpdateMetadata.model_validate(body) + job.destruction_time = update.destruction_time + job.execution_duration = update.execution_duration + return Response(200, json=job.model_dump(mode="json")) + + # Otherwise, handle all the phase modification cases. + match body["phase"]: + case ExecutionPhase.ABORTED: + _ = JobUpdateAborted.model_validate(body) + if job.start_time: + job.end_time = current_datetime() + case ExecutionPhase.COMPLETED: + completed_update = JobUpdateCompleted.model_validate(body) + job.end_time = current_datetime() + job.results = completed_update.results + case ExecutionPhase.EXECUTING: + executing_update = JobUpdateExecuting.model_validate(body) + start_time = executing_update.start_time.replace(microsecond=0) + job.start_time = start_time + case ExecutionPhase.ERROR: + error_update = JobUpdateError.model_validate(body) + job.end_time = current_datetime() + job.errors = error_update.errors + case ExecutionPhase.QUEUED: + queued_update = JobUpdateQueued.model_validate(body) + job.message_id = queued_update.message_id + job.phase = ExecutionPhase[body["phase"]] + return Response(200, json=job.model_dump(mode="json")) + + def _get_auth(self, request: Request) -> tuple[str, str]: + """Parse the fake internal token into service and username.""" + auth_type, token = request.headers["Authorization"].split() + assert auth_type.lower() == "bearer" + assert token.startswith("gt-") + service, username = token[3:].split(".", 1) + return (service, username) + + +def patch_wobbly(respx_mock: respx.Router, wobbly_url: str) -> MockWobbly: + """Set up the mock for a Wobbly server. + + Parameters + ---------- + respx_mock + Mock router. + wobbly_url + Base URL on which Wobbly is "listening." + + Returns + ------- + MockWobbly + Mock Wobbly service. 
+ """ + wobbly_url = wobbly_url.rstrip("/") + mock = MockWobbly() + respx_mock.get(wobbly_url + "/jobs").mock(side_effect=mock.list_jobs) + respx_mock.post(wobbly_url + "/jobs").mock(side_effect=mock.create_job) + job_url = rf"{wobbly_url}/jobs/(?P<job_id>[^/]+)" + respx_mock.get(url__regex=job_url).mock(side_effect=mock.get_job) + respx_mock.delete(url__regex=job_url).mock(side_effect=mock.delete_job) + respx_mock.patch(url__regex=job_url).mock(side_effect=mock.update_job) + return mock + + class MockUWSJobRunner: """Simulate execution of jobs with a mock queue. @@ -52,10 +265,6 @@ class MockUWSJobRunner: manually updating state in the mock queue and running the UWS database worker functions that normally would be run automatically by the queue. - This class wraps that functionality in an async context manager. An - instance of it is normally provided as a fixture, initialized with the - same test objects as the test suite. - Parameters ---------- config @@ -65,49 +274,26 @@ class MockUWSJobRunner: """ def __init__(self, config: UWSConfig, arq_queue: MockArqQueue) -> None: - self._config = config self._arq = arq_queue - self._engine: AsyncEngine - self._store: JobStore - self._service: JobService - async def __aenter__(self) -> Self: - """Create a database session and the underlying service.""" # This duplicates some of the code in UWSDependency to avoid needing # to set up the result store or to expose UWSFactory outside of the # Safir package internals. 
- self._engine = create_database_engine( - self._config.database_url, - self._config.database_password, - isolation_level="REPEATABLE READ", - ) - session = await create_async_session(self._engine) - self._store = JobStore(session) + self._store = JobStore(config, AsyncClient()) self._service = JobService( - config=self._config, + config=config, arq_queue=self._arq, storage=self._store, logger=structlog.get_logger("uws"), ) - return self - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> Literal[False]: - """Close the database engine and session.""" - await self._engine.dispose() - return False - - async def get_job_metadata( - self, username: str, job_id: str - ) -> JobMetadata: + async def get_job_metadata(self, token: str, job_id: str) -> JobMetadata: """Get the arq job metadata for a job. Parameters ---------- + token + Token for the user. job_id UWS job ID. @@ -116,15 +302,17 @@ async def get_job_metadata( JobMetadata arq job metadata. """ - job = await self._service.get(username, job_id) + job = await self._service.get(token, job_id) assert job.message_id return await self._arq.get_job_metadata(job.message_id) - async def get_job_result(self, username: str, job_id: str) -> JobResult: + async def get_job_result(self, token: str, job_id: str) -> JobResult: """Get the arq job result for a job. Parameters ---------- + token + Token for the user. job_id UWS job ID. @@ -133,19 +321,19 @@ async def get_job_result(self, username: str, job_id: str) -> JobResult: JobMetadata arq job metadata. 
""" - job = await self._service.get(username, job_id) + job = await self._service.get(token, job_id) assert job.message_id return await self._arq.get_job_result(job.message_id) async def mark_in_progress( - self, username: str, job_id: str, *, delay: float | None = None - ) -> UWSJob: + self, token: str, job_id: str, *, delay: float | None = None + ) -> Job: """Mark a queued job in progress. Parameters ---------- - username - Owner of job. + token + Token for the user. job_id Job ID. delay @@ -153,31 +341,31 @@ async def mark_in_progress( Returns ------- - UWSJob + Job Record of the job. """ if delay: await asyncio.sleep(delay) - job = await self._service.get(username, job_id) + job = await self._service.get(token, job_id) assert job.message_id await self._arq.set_in_progress(job.message_id) - await self._store.mark_executing(job_id, datetime.now(tz=UTC)) - return await self._service.get(username, job_id) + await self._store.mark_executing(token, job_id, datetime.now(tz=UTC)) + return await self._service.get(token, job_id) async def mark_complete( self, - username: str, + token: str, job_id: str, - results: list[UWSJobResult] | Exception, + results: list[WorkerResult] | Exception, *, delay: float | None = None, - ) -> UWSJob: + ) -> Job: """Mark an in progress job as complete. Parameters ---------- - username - Owner of job. + token + Token for the user. job_id Job ID. results @@ -187,14 +375,14 @@ async def mark_complete( Returns ------- - UWSJob + Job Record of the job. 
""" if delay: await asyncio.sleep(delay) - job = await self._service.get(username, job_id) + job = await self._service.get(token, job_id) assert job.message_id await self._arq.set_complete(job.message_id, result=results) job_result = await self._arq.get_job_result(job.message_id) - await self._store.mark_completed(job_id, job_result) - return await self._service.get(username, job_id) + await self._store.mark_completed(token, job_id, job_result) + return await self._service.get(token, job_id) diff --git a/safir/src/safir/uws/__init__.py b/safir/src/safir/uws/__init__.py index f2bd0ce7..346dba1a 100644 --- a/safir/src/safir/uws/__init__.py +++ b/safir/src/safir/uws/__init__.py @@ -2,40 +2,42 @@ from ._app import UWSApplication from ._config import UWSAppSettings, UWSConfig, UWSRoute -from ._exceptions import ( - DatabaseSchemaError, - MultiValuedParameterError, - ParameterError, - ParameterParseError, - UsageError, - UWSError, -) +from ._exceptions import ParameterError, UsageError, UWSError from ._models import ( - ErrorCode, + Job, + JobBase, + JobCreate, + JobError, + JobResult, + JobUpdateAborted, + JobUpdateCompleted, + JobUpdateError, + JobUpdateExecuting, + JobUpdateMetadata, + JobUpdateQueued, ParametersModel, - UWSJob, - UWSJobError, - UWSJobParameter, - UWSJobResult, + SerializedJob, ) -from ._schema import UWSSchemaBase __all__ = [ - "DatabaseSchemaError", - "ErrorCode", - "MultiValuedParameterError", + "Job", + "JobBase", + "JobCreate", + "JobError", + "JobResult", + "JobUpdateAborted", + "JobUpdateCompleted", + "JobUpdateError", + "JobUpdateExecuting", + "JobUpdateMetadata", + "JobUpdateQueued", "ParameterError", - "ParameterParseError", "ParametersModel", + "SerializedJob", "UWSAppSettings", "UWSApplication", "UWSConfig", "UWSError", - "UWSJob", - "UWSJobError", - "UWSJobParameter", - "UWSJobResult", "UWSRoute", - "UWSSchemaBase", "UsageError", ] diff --git a/safir/src/safir/uws/_app.py b/safir/src/safir/uws/_app.py index 6661f396..8cfa3a45 100644 --- 
a/safir/src/safir/uws/_app.py +++ b/safir/src/safir/uws/_app.py @@ -2,11 +2,8 @@ from __future__ import annotations -from dataclasses import asdict -from pathlib import Path from typing import Any -from arq import cron from fastapi import APIRouter, FastAPI, Request from fastapi.exceptions import RequestValidationError from fastapi.responses import PlainTextResponse @@ -14,21 +11,15 @@ from safir.arq import ArqQueue, WorkerSettings from safir.arq.uws import UWS_QUEUE_NAME -from safir.database import ( - create_database_engine, - initialize_database, - is_database_current, - stamp_database_async, -) from safir.middleware.ivoa import ( CaseInsensitiveFormMiddleware, CaseInsensitiveQueryMiddleware, ) from ._config import UWSConfig -from ._constants import UWS_DATABASE_TIMEOUT, UWS_EXPIRE_JOBS_SCHEDULE +from ._constants import UWS_DATABASE_TIMEOUT from ._dependencies import uws_dependency -from ._exceptions import DatabaseSchemaError, UWSError +from ._exceptions import UWSError from ._handlers import ( install_async_post_handler, install_availability_handler, @@ -36,11 +27,9 @@ install_sync_post_handler, uws_router, ) -from ._schema import UWSSchemaBase from ._workers import ( close_uws_worker_context, create_uws_worker_context, - uws_expire_jobs, uws_job_completed, uws_job_started, ) @@ -51,7 +40,7 @@ async def _uws_error_handler( request: Request, exc: UWSError ) -> PlainTextResponse: - response = f"{exc.error_code.value}: {exc!s}\n" + response = f"{exc.error_code}: {exc!s}\n" if exc.detail: response += "\n{exc.detail}" return PlainTextResponse(response, status_code=exc.status_code) @@ -82,13 +71,7 @@ class UWSApplication: def __init__(self, config: UWSConfig) -> None: self._config = config - def build_worker( - self, - logger: BoundLogger, - *, - check_schema: bool = False, - alembic_config_path: Path = Path("alembic.ini"), - ) -> WorkerSettings: + def build_worker(self, logger: BoundLogger) -> WorkerSettings: """Construct an arq worker configuration for the UWS 
worker. All UWS job status and results must be stored in the underlying @@ -109,23 +92,10 @@ def build_worker( ---------- logger Logger to use for messages. - check_schema - Whether to check the database schema version with Alembic on - startup. - alembic_config_path - When checking the schema, use this path to the Alembic - configuration. """ async def startup(ctx: dict[Any, Any]) -> None: - ctx.update( - await create_uws_worker_context( - self._config, - logger, - check_schema=check_schema, - alembic_config_path=alembic_config_path, - ) - ) + ctx.update(await create_uws_worker_context(self._config, logger)) async def shutdown(ctx: dict[Any, Any]) -> None: await close_uws_worker_context(ctx) @@ -133,14 +103,6 @@ async def shutdown(ctx: dict[Any, Any]) -> None: # Running 10 jobs simultaneously is the arq default as of arq 0.26.0 # and seems reasonable for database workers. return WorkerSettings( - cron_jobs=[ - cron( - uws_expire_jobs, - unique=True, - timeout=UWS_DATABASE_TIMEOUT, - **asdict(UWS_EXPIRE_JOBS_SCHEDULE), - ) - ], functions=[uws_job_started, uws_job_completed], redis_settings=self._config.arq_redis_settings, job_completion_wait=UWS_DATABASE_TIMEOUT, @@ -151,71 +113,14 @@ async def shutdown(ctx: dict[Any, Any]) -> None: on_shutdown=shutdown, ) - async def initialize_fastapi( - self, - logger: BoundLogger | None = None, - *, - check_schema: bool = False, - alembic_config_path: Path = Path("alembic.ini"), - ) -> None: + async def initialize_fastapi(self) -> None: """Initialize the UWS subsystem for FastAPI applications. This must be called before any UWS routes are accessed, normally from the lifespan function of the FastAPI application. - - Parameters - ---------- - logger - Logger to use to report any problems. - check_schema - If `True`, check whether the database schema for the UWS database - is up to date using Alembic. - alembic_config_path - When checking the schema, use this path to the Alembic - configuration. 
- - Raises - ------ - DatabaseSchemaError - Raised if the UWS database schema is out of date. """ - if check_schema: - if not await self.is_schema_current(logger, alembic_config_path): - raise DatabaseSchemaError("UWS database schema out of date") await uws_dependency.initialize(self._config) - async def initialize_uws_database( - self, - logger: BoundLogger, - *, - reset: bool = False, - use_alembic: bool = False, - alembic_config_path: Path = Path("alembic.ini"), - ) -> None: - """Initialize the UWS database. - - Parameters - ---------- - logger - Logger to use. - reset - If `True`, also delete all data in the database. - use_alembic - Whether to stamp the UWS database with Alembic. - alembic_config_path - When stamping the database, use this path to the Alembic - configuration. - """ - engine = create_database_engine( - self._config.database_url, self._config.database_password - ) - await initialize_database( - engine, logger, schema=UWSSchemaBase.metadata, reset=reset - ) - if use_alembic: - await stamp_database_async(engine, alembic_config_path) - await engine.dispose() - def install_error_handlers(self, app: FastAPI) -> None: """Install error handlers that follow DALI and UWS conventions. @@ -276,28 +181,6 @@ def install_middleware(self, app: FastAPI) -> None: app.add_middleware(CaseInsensitiveFormMiddleware) app.add_middleware(CaseInsensitiveQueryMiddleware) - async def is_schema_current( - self, - logger: BoundLogger | None = None, - config_path: Path = Path("alembic.ini"), - ) -> bool: - """Check that the database schema is current using Alembic. - - Parameters - ---------- - logger - Logger to use to report any problems. - config_path - Path to the Alembic configuration. 
- engine = create_database_engine( - self._config.database_url, self._config.database_password - ) - try: - return await is_database_current(engine, logger, config_path) - finally: - await engine.dispose() - def override_arq_queue(self, arq_queue: ArqQueue) -> None: """Change the arq used by the FastAPI route handlers. @@ -314,6 +197,7 @@ async def shutdown_fastapi(self) -> None: """Shut down the UWS subsystem for FastAPI applications. This should be called during application shutdown, normally from the - lifespan function of the FastAPI application. + lifespan function of the FastAPI application. Currently, this does + nothing, but it remains as a hook in case some shutdown is required in + the future. """ - await uws_dependency.aclose() diff --git a/safir/src/safir/uws/_config.py b/safir/src/safir/uws/_config.py index 7c66494b..a9f58abd 100644 --- a/safir/src/safir/uws/_config.py +++ b/safir/src/safir/uws/_config.py @@ -8,26 +8,20 @@ from typing import TypeAlias from arq.connections import RedisSettings -from pydantic import Field, SecretStr -from pydantic_core import Url +from pydantic import Field, HttpUrl, SecretStr from pydantic_settings import BaseSettings from vo_models.uws import JobSummary from safir.arq import ArqMode, build_arq_redis_settings -from safir.pydantic import ( - EnvAsyncPostgresDsn, - EnvRedisDsn, - HumanTimedelta, - SecondsTimedelta, -) +from safir.pydantic import EnvRedisDsn, HumanTimedelta, SecondsTimedelta -from ._models import ParametersModel, UWSJob, UWSJobParameter +from ._models import Job, ParametersModel -DestructionValidator: TypeAlias = Callable[[datetime, UWSJob], datetime] +DestructionValidator: TypeAlias = Callable[[datetime, Job], datetime] """Type for a validator for a new destruction time.""" ExecutionDurationValidator: TypeAlias = Callable[ - [timedelta, UWSJob], timedelta + [timedelta | None, Job], timedelta | None ] """Type for a validator for a new execution duration.""" @@ -42,7 +36,7 @@ class UWSRoute: 
"""Defines a FastAPI dependency to get the UWS job parameters.""" - dependency: Callable[..., Coroutine[None, None, list[UWSJobParameter]]] + dependency: Callable[..., Coroutine[None, None, ParametersModel]] """Type for a dependency that gathers parameters for a job.""" summary: str @@ -71,13 +65,9 @@ class encapsulates the configuration of the UWS component that may vary by """Route configuration for creating an async job via POST. The FastAPI dependency included in this object should expect POST - parameters and return a list of `~safir.uws.UWSJobParameter` objects - representing the job parameters. + parameters and return the Pydantic model of the job parameters. """ - database_url: str | Url - """URL for the metadata database.""" - execution_duration: timedelta """Maximum execution time in seconds. @@ -117,12 +107,12 @@ class encapsulates the configuration of the UWS component that may vary by with this email. """ + wobbly_url: str | HttpUrl + """URL to the Wobbly UWS job tracking API.""" + worker: str """Name of the backend worker to call to execute a job.""" - database_password: SecretStr | None = None - """Password for the database.""" - slack_webhook: SecretStr | None = None """Slack incoming webhook for reporting errors.""" @@ -130,18 +120,16 @@ class encapsulates the configuration of the UWS component that may vary by """Route configuration for creating a sync job via GET. The FastAPI dependency included in this object should expect GET - parameters and return a list of `~safir.uws.UWSJobParameter` objects - representing the job parameters. If `None`, no route to create a job via - sync GET will be created. + parameters and return the Pydantic model of the job parameters. If `None`, + no route to create a job via sync GET will be created. """ sync_post_route: UWSRoute | None = None """Route configuration for creating a sync job via POST. 
The FastAPI dependency included in this object should expect POST - parameters and return a list of `~safir.uws.UWSJobParameter` objects - representing the job parameters. If `None`, no route to create a job via - sync POST will be created. + parameters and return the Pydantic model of the job parameters. If `None`, + no route to create a job via sync POST will be created. """ sync_timeout: timedelta = timedelta(minutes=5) @@ -195,16 +183,6 @@ class UWSAppSettings(BaseSettings): description="Password of Redis server to use for the arq queue", ) - database_url: EnvAsyncPostgresDsn = Field( - ..., - title="PostgreSQL DSN", - description="DSN of PostgreSQL database for UWS job tracking", - ) - - database_password: SecretStr | None = Field( - None, title="Password for UWS job database" - ) - grace_period: SecondsTimedelta = Field( timedelta(seconds=30), title="Grace period for jobs", @@ -251,6 +229,16 @@ class UWSAppSettings(BaseSettings): ), ) + wobbly_url: HttpUrl = Field( + ..., + title="Wobbly URL", + description="URL to Wobbly UWS job tracking API", + ) + + database_password: SecretStr | None = Field( + None, title="Password for UWS job database" + ) + @property def arq_redis_settings(self) -> RedisSettings: """Redis settings for arq.""" @@ -285,8 +273,8 @@ def build_uws_config( async_post_route Route configuration for job parameters for an async job via POST. The FastAPI dependency included in this object should expect - POST parameters and return a list of `~safir.uws.UWSJobParameter` - objects representing the job parameters. + POST parameters and return a Pydantic model representing the job + parameters. job_summary_type Type representing the XML job summary type, qualified with an appropriate subclass of `~vo_models.uws.models.Parameters`. That @@ -301,15 +289,13 @@ def build_uws_config( sync_get_route Route configuration for creating a sync job via GET. 
The FastAPI dependency included in this object should expect GET parameters - and return a list of `~safir.uws.UWSJobParameter` objects - representing the job parameters. If `None`, no route to create a - job via sync GET will be created. + and return a Pydantic model representing the job parameters. If + `None`, no route to create a job via sync GET will be created. sync_post_route Route configuration for creating a sync job via POST. The FastAPI dependency included in this object should expect POST parameters - and return a list of `~safir.uws.UWSJobParameter` objects - representing the job parameters. If `None`, no route to create a - job via sync POST will be created. + and return a Pydantic model representing the job parameters. If + `None`, no route to create a job via sync POST will be created. url_lifetime How long result URLs should be valid for. validate_destruction @@ -368,8 +354,6 @@ def uws_config(self) -> UWSConfig: parameters_type=parameters_type, signing_service_account=self.service_account, worker=worker, - database_url=self.database_url, - database_password=self.database_password, slack_webhook=slack_webhook, sync_timeout=self.sync_timeout, async_post_route=async_post_route, @@ -379,4 +363,5 @@ def uws_config(self) -> UWSConfig: validate_destruction=validate_destruction, validate_execution_duration=validate_execution_duration, wait_timeout=wait_timeout, + wobbly_url=self.wobbly_url, ) diff --git a/safir/src/safir/uws/_constants.py b/safir/src/safir/uws/_constants.py index bd1e67f7..8d268c74 100644 --- a/safir/src/safir/uws/_constants.py +++ b/safir/src/safir/uws/_constants.py @@ -4,13 +4,11 @@ from datetime import timedelta -from arq.cron import Options - __all__ = [ "JOB_RESULT_TIMEOUT", "JOB_STOP_TIMEOUT", "UWS_DATABASE_TIMEOUT", - "UWS_EXPIRE_JOBS_SCHEDULE", + "WOBBLY_REQUEST_TIMEOUT", ] JOB_RESULT_TIMEOUT = timedelta(seconds=5) @@ -25,13 +23,5 @@ This should match the default Kubernetes grace period for a pod to shut down. 
""" -UWS_EXPIRE_JOBS_SCHEDULE = Options( - month=None, - day=None, - weekday=None, - hour=None, - minute=5, - second=0, - microsecond=0, -) -"""Schedule for job expiration cron job, as `arq.cron.cron` parameters.""" +WOBBLY_REQUEST_TIMEOUT = 20 +"""Timeout in seconds for Wobbly HTTP requests.""" diff --git a/safir/src/safir/uws/_dependencies.py b/safir/src/safir/uws/_dependencies.py index 417ea760..c5dcc29a 100644 --- a/safir/src/safir/uws/_dependencies.py +++ b/safir/src/safir/uws/_dependencies.py @@ -5,15 +5,14 @@ to individual route handlers, which in turn can create other needed objects. """ -from collections.abc import AsyncIterator from typing import Annotated, Literal from fastapi import Depends, Form, Query -from sqlalchemy.ext.asyncio import AsyncEngine, async_scoped_session +from httpx import AsyncClient from structlog.stdlib import BoundLogger from safir.arq import ArqMode, ArqQueue, MockArqQueue, RedisArqQueue -from safir.database import create_async_session, create_database_engine +from safir.dependencies.http_client import http_client_dependency from safir.dependencies.logger import logger_dependency from ._config import UWSConfig @@ -40,33 +39,25 @@ class UWSFactory: UWS configuration. arq arq queue to use. - session - Database session. result_store Signed URL generator for results. logger Logger to use. - - Attributes - ---------- - session - Database session. This is exposed primarily for the test suite. It - shouldn't be necessary for other code to use it directly. 
""" def __init__( self, *, config: UWSConfig, - arq: ArqQueue, - session: async_scoped_session, result_store: ResultStore, + arq: ArqQueue, + http_client: AsyncClient, logger: BoundLogger, ) -> None: - self.session = session self._config = config - self._arq = arq self._result_store = result_store + self._arq = arq + self._http_client = http_client self._logger = logger def create_result_store(self) -> ResultStore: @@ -84,11 +75,11 @@ def create_job_service(self) -> JobService: def create_job_store(self) -> JobStore: """Create a new UWS job store.""" - return JobStore(self.session) + return JobStore(self._config, self._http_client) def create_templates(self) -> UWSTemplates: """Create a new XML renderer for responses.""" - return UWSTemplates(self._result_store) + return UWSTemplates() class UWSDependency: @@ -97,33 +88,22 @@ class UWSDependency: def __init__(self) -> None: self._arq: ArqQueue | None = None self._config: UWSConfig - self._engine: AsyncEngine - self._session: async_scoped_session self._result_store: ResultStore async def __call__( - self, logger: Annotated[BoundLogger, Depends(logger_dependency)] - ) -> AsyncIterator[UWSFactory]: + self, + http_client: Annotated[AsyncClient, Depends(http_client_dependency)], + logger: Annotated[BoundLogger, Depends(logger_dependency)], + ) -> UWSFactory: if not self._arq: raise RuntimeError("UWSDependency not initialized") - try: - yield UWSFactory( - config=self._config, - arq=self._arq, - session=self._session, - result_store=self._result_store, - logger=logger, - ) - finally: - # Following the recommendations in the SQLAlchemy documentation, - # each session is scoped to a single web request. However, this - # all uses the same async_scoped_session object, so should share - # an underlying engine and connection pool. 
- await self._session.remove() - - async def aclose(self) -> None: - """Shut down the UWS subsystem.""" - await self._engine.dispose() + return UWSFactory( + config=self._config, + result_store=self._result_store, + arq=self._arq, + http_client=http_client, + logger=logger, + ) async def initialize(self, config: UWSConfig) -> None: """Initialize the UWS subsystem. @@ -141,12 +121,6 @@ async def initialize(self, config: UWSConfig) -> None: self._arq = await RedisArqQueue.initialize(settings) else: self._arq = MockArqQueue() - self._engine = create_database_engine( - config.database_url, - config.database_password, - isolation_level="REPEATABLE READ", - ) - self._session = await create_async_session(self._engine) def override_arq_queue(self, arq_queue: ArqQueue) -> None: """Change the arq used in subsequent invocations. diff --git a/safir/src/safir/uws/_exceptions.py b/safir/src/safir/uws/_exceptions.py index 53c85081..edd0f9ef 100644 --- a/safir/src/safir/uws/_exceptions.py +++ b/safir/src/safir/uws/_exceptions.py @@ -15,19 +15,16 @@ SlackMessage, SlackTextBlock, SlackTextField, + SlackWebException, ) from safir.slack.webhook import SlackIgnoredException -from ._models import ErrorCode, UWSJobError, UWSJobParameter +from ._models import JobError __all__ = [ "DataMissingError", - "DatabaseSchemaError", "InvalidPhaseError", - "MultiValuedParameterError", "ParameterError", - "ParameterParseError", - "PermissionDeniedError", "SyncJobFailedError", "SyncJobNoResultsError", "SyncJobTimeoutError", @@ -35,13 +32,10 @@ "UWSError", "UnknownJobError", "UsageError", + "WobblyError", ] -class DatabaseSchemaError(Exception): - """Some problem was detected in the UWS database schema.""" - - class UWSError(SlackIgnoredException): """An error with an associated error code. 
@@ -60,7 +54,7 @@ class UWSError(SlackIgnoredException): """ def __init__( - self, error_code: ErrorCode, message: str, detail: str | None = None + self, error_code: str, message: str, detail: str | None = None ) -> None: super().__init__(message) self.error_code = error_code @@ -68,27 +62,11 @@ def __init__( self.status_code = 400 -class MultiValuedParameterError(UWSError): - """Multiple values not allowed for this parameter.""" - - def __init__(self, message: str) -> None: - super().__init__(ErrorCode.MULTIVALUED_PARAM_NOT_SUPPORTED, message) - self.status_code = 422 - - -class PermissionDeniedError(UWSError): - """User does not have access to this resource.""" - - def __init__(self, message: str) -> None: - super().__init__(ErrorCode.AUTHORIZATION_ERROR, message) - self.status_code = 403 - - class SyncJobFailedError(UWSError): """A sync job failed.""" - def __init__(self, error: UWSJobError) -> None: - super().__init__(error.error_code, error.message, error.detail) + def __init__(self, error: JobError) -> None: + super().__init__(error.code, error.message, error.detail) self.status_code = 500 @@ -97,7 +75,7 @@ class SyncJobNoResultsError(UWSError): def __init__(self) -> None: msg = "Job completed but produced no results" - super().__init__(ErrorCode.ERROR, msg) + super().__init__("Error", msg) self.status_code = 500 @@ -106,7 +84,7 @@ class SyncJobTimeoutError(UWSError): def __init__(self, timeout: timedelta) -> None: msg = f"Job did not complete in {timeout.total_seconds()}s" - super().__init__(ErrorCode.ERROR, msg) + super().__init__("Error", msg) self.status_code = 500 @@ -150,7 +128,7 @@ class TaskError(SlackException): def __init__( self, *, - error_code: ErrorCode, + error_code: str, error_type: ErrorType, message: str, detail: str | None = None, @@ -186,13 +164,13 @@ def from_worker_error(cls, exc: WorkerError) -> Self: slack_ignore = False match exc.error_type: case WorkerErrorType.FATAL: - error_code = ErrorCode.ERROR + error_code = "Error" error_type = 
ErrorType.FATAL case WorkerErrorType.TRANSIENT: - error_code = ErrorCode.SERVICE_UNAVAILABLE + error_code = "ServiceUnavailable" error_type = ErrorType.TRANSIENT case WorkerErrorType.USAGE: - error_code = ErrorCode.USAGE_ERROR + error_code = "UsageError" error_type = ErrorType.FATAL slack_ignore = True return cls( @@ -205,15 +183,15 @@ def from_worker_error(cls, exc: WorkerError) -> Self: slack_ignore=slack_ignore, ) - def to_job_error(self) -> UWSJobError: + def to_job_error(self) -> JobError: """Convert to a `~safir.uws._models.UWSJobError`.""" if self._traceback and self._detail: detail: str | None = self._detail + "\n\n" + self._traceback else: detail = self._detail or self._traceback - return UWSJobError( - error_code=self._error_code, - error_type=self._error_type, + return JobError( + code=self._error_code, + type=self._error_type, message=self._message, detail=detail, ) @@ -245,7 +223,7 @@ class UsageError(UWSError): """Invalid parameters were passed to a UWS API.""" def __init__(self, message: str, detail: str | None = None) -> None: - super().__init__(ErrorCode.USAGE_ERROR, message, detail) + super().__init__("UsageError", message, detail) self.status_code = 422 @@ -253,7 +231,7 @@ class DataMissingError(UWSError): """The data requested does not exist for that job.""" def __init__(self, message: str) -> None: - super().__init__(ErrorCode.USAGE_ERROR, message) + super().__init__("UsageError", message) self.status_code = 404 @@ -265,18 +243,13 @@ class ParameterError(UsageError): """Unsupported value passed to a parameter.""" -class ParameterParseError(ParameterError): - """UWS job parameters could not be parsed.""" - - def __init__(self, message: str, params: list[UWSJobParameter]) -> None: - detail = "\n".join(f"{p.parameter_id}={p.value}" for p in params) - super().__init__(message, detail) - self.params = params - - class UnknownJobError(DataMissingError): """The named job could not be found in the database.""" def __init__(self, job_id: str) -> None: 
super().__init__(f"Job {job_id} not found") self.job_id = job_id + + +class WobblyError(SlackWebException): + """An error occurred making a request to Wobbly.""" diff --git a/safir/src/safir/uws/_handlers.py b/safir/src/safir/uws/_handlers.py index d97f36d8..5a085ed4 100644 --- a/safir/src/safir/uws/_handlers.py +++ b/safir/src/safir/uws/_handlers.py @@ -11,7 +11,6 @@ from fastapi import APIRouter, Depends, Form, Query, Request, Response from fastapi.responses import PlainTextResponse, RedirectResponse -from structlog.stdlib import BoundLogger from vo_models.uws import Jobs, JobSummary, Results from vo_models.uws.types import ExecutionPhase @@ -19,7 +18,6 @@ from safir.dependencies.gafaelfawr import ( auth_delegated_token_dependency, auth_dependency, - auth_logger_dependency, ) from safir.pydantic import IvoaIsoDatetime from safir.slack.webhook import SlackRouteErrorHandler @@ -32,7 +30,7 @@ uws_dependency, ) from ._exceptions import DataMissingError -from ._models import UWSJobParameter +from ._models import ParametersModel uws_router = APIRouter(route_class=SlackRouteErrorHandler) """FastAPI router for all external handlers.""" @@ -85,12 +83,12 @@ async def get_job_list( description="Return at most the given number of jobs", ), ] = None, - user: Annotated[str, Depends(auth_dependency)], + token: Annotated[str, Depends(auth_delegated_token_dependency)], uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], ) -> Response: job_service = uws_factory.create_job_service() jobs = await job_service.list_jobs( - user, + token, base_url=str(request.url_for("get_job_list")), phases=phase, after=after, @@ -137,13 +135,13 @@ async def get_job( ), ), ] = None, - user: Annotated[str, Depends(auth_dependency)], + token: Annotated[str, Depends(auth_delegated_token_dependency)], uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], ) -> Response: job_service = uws_factory.create_job_service() result_store = uws_factory.create_result_store() job = await 
job_service.get_summary( - user, job_id, signer=result_store, wait_seconds=wait, wait_phase=phase + token, job_id, signer=result_store, wait_seconds=wait, wait_phase=phase ) xml = job.to_xml(skip_empty=True) return Response(content=xml, media_type="application/xml") @@ -159,11 +157,11 @@ async def delete_job( job_id: str, *, request: Request, - user: Annotated[str, Depends(auth_dependency)], + token: Annotated[str, Depends(auth_delegated_token_dependency)], uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], ) -> str: job_service = uws_factory.create_job_service() - await job_service.delete(user, job_id) + await job_service.delete(token, job_id) return str(request.url_for("get_job_list")) @@ -187,11 +185,11 @@ async def delete_job_via_post( ), ] = None, request: Request, - user: Annotated[str, Depends(auth_dependency)], + token: Annotated[str, Depends(auth_delegated_token_dependency)], uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], ) -> str: job_service = uws_factory.create_job_service() - await job_service.delete(user, job_id) + await job_service.delete(token, job_id) return str(request.url_for("get_job_list")) @@ -203,11 +201,11 @@ async def delete_job_via_post( async def get_job_destruction( job_id: str, *, - user: Annotated[str, Depends(auth_dependency)], + token: Annotated[str, Depends(auth_delegated_token_dependency)], uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], ) -> str: job_service = uws_factory.create_job_service() - job = await job_service.get(user, job_id) + job = await job_service.get(token, job_id) return isodatetime(job.destruction_time) @@ -229,11 +227,11 @@ async def post_job_destruction( ), ], request: Request, - user: Annotated[str, Depends(auth_dependency)], + token: Annotated[str, Depends(auth_delegated_token_dependency)], uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], ) -> str: job_service = uws_factory.create_job_service() - await job_service.update_destruction(user, job_id, destruction) + 
await job_service.update_destruction(token, job_id, destruction) return str(request.url_for("get_job", job_id=job_id)) @@ -249,15 +247,15 @@ async def get_job_error( job_id: str, *, request: Request, - user: Annotated[str, Depends(auth_dependency)], + token: Annotated[str, Depends(auth_delegated_token_dependency)], uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], ) -> Response: job_service = uws_factory.create_job_service() - error = await job_service.get_error(user, job_id) - if not error: + errors = await job_service.get_error(token, job_id) + if not errors: raise DataMissingError(f"Job {job_id} did not fail") templates = uws_factory.create_templates() - return templates.error(request, error) + return templates.error(request, errors[0]) @uws_router.get( @@ -268,12 +266,15 @@ async def get_job_error( async def get_job_execution_duration( job_id: str, *, - user: Annotated[str, Depends(auth_dependency)], + token: Annotated[str, Depends(auth_delegated_token_dependency)], uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], ) -> str: job_service = uws_factory.create_job_service() - job = await job_service.get(user, job_id) - return str(int(job.execution_duration.total_seconds())) + job = await job_service.get(token, job_id) + if job.execution_duration: + return str(int(job.execution_duration.total_seconds())) + else: + return "0" @uws_router.post( @@ -291,16 +292,18 @@ async def post_job_execution_duration( title="New execution duration", description="Integer seconds of wall clock time.", examples=[14400], - ge=1, + ge=0, ), ], request: Request, - user: Annotated[str, Depends(auth_dependency)], + token: Annotated[str, Depends(auth_delegated_token_dependency)], uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], ) -> str: job_service = uws_factory.create_job_service() - duration = timedelta(seconds=executionduration) - await job_service.update_execution_duration(user, job_id, duration) + duration = None + if executionduration > 0: + 
duration = timedelta(seconds=executionduration) + await job_service.update_execution_duration(token, job_id, duration) return str(request.url_for("get_job", job_id=job_id)) @@ -312,11 +315,11 @@ async def post_job_execution_duration( async def get_job_owner( job_id: str, *, - user: Annotated[str, Depends(auth_dependency)], + token: Annotated[str, Depends(auth_delegated_token_dependency)], uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], ) -> str: job_service = uws_factory.create_job_service() - job = await job_service.get(user, job_id) + job = await job_service.get(token, job_id) return job.owner @@ -329,11 +332,11 @@ async def get_job_parameters( job_id: str, *, request: Request, - user: Annotated[str, Depends(auth_dependency)], + token: Annotated[str, Depends(auth_delegated_token_dependency)], uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], ) -> Response: job_service = uws_factory.create_job_service() - job = await job_service.get_summary(user, job_id) + job = await job_service.get_summary(token, job_id) if not job.parameters: raise DataMissingError(f"Job {job_id} has no parameters") xml = job.parameters.to_xml(skip_empty=True) @@ -348,11 +351,11 @@ async def get_job_parameters( async def get_job_phase( job_id: str, *, - user: Annotated[str, Depends(auth_dependency)], + token: Annotated[str, Depends(auth_delegated_token_dependency)], uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], ) -> str: job_service = uws_factory.create_job_service() - job = await job_service.get(user, job_id) + job = await job_service.get(token, job_id) return job.phase.value @@ -374,15 +377,14 @@ async def post_job_phase( ] = None, request: Request, user: Annotated[str, Depends(auth_dependency)], - access_token: Annotated[str, Depends(auth_delegated_token_dependency)], + token: Annotated[str, Depends(auth_delegated_token_dependency)], uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], - logger: Annotated[BoundLogger, 
Depends(auth_logger_dependency)], ) -> str: job_service = uws_factory.create_job_service() if phase == "ABORT": - await job_service.abort(user, job_id) + await job_service.abort(token, job_id) elif phase == "RUN": - await job_service.start(user, job_id, access_token) + await job_service.start(token, user, job_id) return str(request.url_for("get_job", job_id=job_id)) @@ -394,11 +396,11 @@ async def post_job_phase( async def get_job_quote( job_id: str, *, - user: Annotated[str, Depends(auth_dependency)], + token: Annotated[str, Depends(auth_delegated_token_dependency)], uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], ) -> str: job_service = uws_factory.create_job_service() - job = await job_service.get(user, job_id) + job = await job_service.get(token, job_id) if job.quote: return isodatetime(job.quote) else: @@ -422,12 +424,12 @@ async def get_job_results( job_id: str, *, request: Request, - user: Annotated[str, Depends(auth_dependency)], + token: Annotated[str, Depends(auth_delegated_token_dependency)], uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], ) -> Response: job_service = uws_factory.create_job_service() result_store = uws_factory.create_result_store() - job = await job_service.get_summary(user, job_id, signer=result_store) + job = await job_service.get_summary(token, job_id, signer=result_store) if not job.results: raise DataMissingError(f"Job {job_id} has no results") xml = job.results.to_xml(skip_empty=True) @@ -483,19 +485,17 @@ async def create_job( phase: Annotated[ Literal["RUN"] | None, Depends(create_phase_dependency) ] = None, - parameters: Annotated[ - list[UWSJobParameter], Depends(route.dependency) - ], + parameters: Annotated[ParametersModel, Depends(route.dependency)], runid: Annotated[str | None, Depends(runid_post_dependency)], user: Annotated[str, Depends(auth_dependency)], token: Annotated[str, Depends(auth_delegated_token_dependency)], uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], ) -> str: 
job_service = uws_factory.create_job_service() - job = await job_service.create(user, run_id=runid, params=parameters) + job = await job_service.create(token, parameters, run_id=runid) if phase == "RUN": - await job_service.start(user, job.job_id, token) - return str(request.url_for("get_job", job_id=job.job_id)) + await job_service.start(token, user, job.id) + return str(request.url_for("get_job", job_id=job.id)) def install_sync_post_handler(router: APIRouter, route: UWSRoute) -> None: @@ -519,7 +519,7 @@ def install_sync_post_handler(router: APIRouter, route: UWSRoute) -> None: async def post_sync( *, runid: Annotated[str | None, Depends(runid_post_dependency)], - params: Annotated[list[UWSJobParameter], Depends(route.dependency)], + parameters: Annotated[ParametersModel, Depends(route.dependency)], user: Annotated[str, Depends(auth_dependency)], token: Annotated[str, Depends(auth_delegated_token_dependency)], uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], @@ -527,7 +527,7 @@ async def post_sync( job_service = uws_factory.create_job_service() result_store = uws_factory.create_result_store() result = await job_service.run_sync( - user, params, token=token, runid=runid + token, user, parameters, runid=runid ) return result_store.sign_url(result).url @@ -563,7 +563,7 @@ async def get_sync( ), ), ] = None, - params: Annotated[list[UWSJobParameter], Depends(route.dependency)], + parameters: Annotated[ParametersModel, Depends(route.dependency)], user: Annotated[str, Depends(auth_dependency)], token: Annotated[str, Depends(auth_delegated_token_dependency)], uws_factory: Annotated[UWSFactory, Depends(uws_dependency)], @@ -571,6 +571,6 @@ async def get_sync( job_service = uws_factory.create_job_service() result_store = uws_factory.create_result_store() result = await job_service.run_sync( - user, params, token=token, runid=runid + token, user, parameters, runid=runid ) return result_store.sign_url(result).url diff --git a/safir/src/safir/uws/_models.py 
b/safir/src/safir/uws/_models.py index c917ab04..f98fddc9 100644 --- a/safir/src/safir/uws/_models.py +++ b/safir/src/safir/uws/_models.py @@ -8,16 +8,12 @@ from __future__ import annotations from abc import ABC, abstractmethod -from dataclasses import asdict, dataclass -from datetime import datetime, timedelta -from enum import StrEnum -from typing import Generic, Self, TypeVar +from typing import Annotated, Any, Generic, Literal, Self, TypeVar -from pydantic import BaseModel +from pydantic import BaseModel, BeforeValidator, Field, PlainSerializer from vo_models.uws import ( ErrorSummary, JobSummary, - Parameter, Parameters, ResultReference, Results, @@ -25,6 +21,9 @@ ) from vo_models.uws.types import ErrorType, ExecutionPhase, UWSVersion +from safir.arq.uws import WorkerResult +from safir.pydantic import SecondsTimedelta, UtcDatetime + P = TypeVar("P", bound="ParametersModel") """Generic type for the parameters model.""" @@ -39,14 +38,18 @@ __all__ = [ "ACTIVE_PHASES", - "ErrorCode", + "Job", + "JobCreate", + "JobError", + "JobResult", + "JobUpdateAborted", + "JobUpdateCompleted", + "JobUpdateError", + "JobUpdateExecuting", + "JobUpdateMetadata", + "JobUpdateQueued", "ParametersModel", - "UWSJob", - "UWSJobDescription", - "UWSJobError", - "UWSJobParameter", - "UWSJobResult", - "UWSJobResultSigned", + "SignedJobResult", ] @@ -58,49 +61,9 @@ """Phases in which the job is active and can be waited on.""" -class ErrorCode(StrEnum): - """Possible error codes in ``text/plain`` responses. - - The following list of errors is taken from the SODA specification and - therefore may not be appropriate for all DALI services. 
- """ - - AUTHENTICATION_ERROR = "AuthenticationError" - AUTHORIZATION_ERROR = "AuthorizationError" - MULTIVALUED_PARAM_NOT_SUPPORTED = "MultiValuedParamNotSupported" - ERROR = "Error" - SERVICE_UNAVAILABLE = "ServiceUnavailable" - USAGE_ERROR = "UsageError" - - class ParametersModel(BaseModel, ABC, Generic[W, X]): """Defines the interface for a model suitable for job parameters.""" - @classmethod - @abstractmethod - def from_job_parameters(cls, params: list[UWSJobParameter]) -> Self: - """Validate generic UWS parameters and convert to the internal model. - - Parameters - ---------- - params - Generic input job parameters. - - Returns - ------- - ParametersModel - Parsed cutout parameters specific to service. - - Raises - ------ - safir.uws.MultiValuedParameterError - Raised if multiple parameters are provided but not supported. - safir.uws.ParameterError - Raised if one of the parameters could not be parsed. - pydantic.ValidationError - Raised if the parameters do not validate. - """ - @abstractmethod def to_worker_parameters(self) -> W: """Convert to the domain model used by the backend worker.""" @@ -110,74 +73,121 @@ def to_xml_model(self) -> X: """Convert to the XML model used in XML API responses.""" -@dataclass -class UWSJobError: +class JobError(BaseModel): """Failure information about a job.""" - error_type: ErrorType - """Type of the error.""" - - error_code: ErrorCode - """The SODA error code of this error.""" - - message: str - """Brief error message. - - Note that the UWS specification allows a sequence of messages, but we only - use a single message and thus a sequence of length one. 
- """ - - detail: str | None = None - """Extended error message with additional detail.""" - - def to_dict(self) -> dict[str, str | None]: - """Convert to a dictionary, primarily for logging.""" - return asdict(self) + type: Annotated[ + ErrorType, + Field( + title="Error type", + description="Type of the error", + examples=[ErrorType.TRANSIENT, ErrorType.FATAL], + ), + ] + + code: Annotated[ + str, + Field( + title="Error code", + description="Code for this class of error", + examples=["ServiceUnavailable"], + ), + ] + + message: Annotated[ + str, + Field( + title="Error message", + description="Brief error message", + examples=["Short error message"], + ), + ] + + detail: Annotated[ + str | None, + Field( + title="Extended error message", + description="Extended error message with additional detail", + examples=["Some longer error message with details", None], + ), + ] = None def to_xml_model(self) -> ErrorSummary: """Convert to a Pydantic XML model.""" return ErrorSummary( - message=f"{self.error_code.value}: {self.message}", - type=self.error_type, + message=f"{self.code}: {self.message}", + type=self.type, has_detail=self.detail is not None, ) -@dataclass -class UWSJobResult: - """A single result from the job.""" +class JobResult(BaseModel): + """A single result from a job.""" + + id: Annotated[ + str, + Field( + title="Result ID", + description="Identifier for this result", + examples=["image", "metadata"], + ), + ] + + url: Annotated[ + str, + Field( + title="Result URL", + description="URL where the result is stored", + examples=["s3://service-result-bucket/some-file"], + ), + ] + + size: Annotated[ + int | None, + Field( + title="Size of result", + description="Size of the result in bytes if known", + examples=[1238123, None], + ), + ] = None + + mime_type: Annotated[ + str | None, + Field( + title="MIME type of result", + description="MIME type of the result if known", + examples=["application/fits", "application/x-votable+xml", None], + ), + ] = None - 
result_id: str - """Identifier for the result.""" - - url: str - """The URL for the result, which must point into a GCS bucket.""" - - size: int | None = None - """Size of the result in bytes.""" - - mime_type: str | None = None - """MIME type of the result.""" + @classmethod + def from_worker_result(cls, result: WorkerResult) -> Self: + """Convert from the `~safir.arq.uws.WorkerResult` model.""" + return cls( + id=result.result_id, + url=result.url, + size=result.size, + mime_type=result.mime_type, + ) def to_xml_model(self) -> ResultReference: """Convert to a Pydantic XML model.""" return ResultReference( - id=self.result_id, size=self.size, mime_type=self.mime_type + id=self.id, size=self.size, mime_type=self.mime_type ) -@dataclass -class UWSJobResultSigned(UWSJobResult): +class SignedJobResult(JobResult): """A single result from the job with a signed URL. - A `UWSJobResult` is converted to a `UWSJobResultSigned` before generating - the response via templating or returning the URL as a redirect. + A `JobResult` is converted to a `SignedJobResult` before generating the + response via templating or returning the URL as a redirect. """ def to_xml_model(self) -> ResultReference: """Convert to a Pydantic XML model.""" return ResultReference( - id=self.result_id, + id=self.id, type=None, href=self.url, size=self.size, @@ -185,59 +195,195 @@ def to_xml_model(self) -> ResultReference: ) -@dataclass -class UWSJobParameter: - """An input parameter to the job.""" - - parameter_id: str - """Identifier of the parameter.""" - - value: str - """Value of the parameter.""" - - def to_dict(self) -> dict[str, str | bool]: - """Convert to a dictionary, primarily for logging.""" - return asdict(self) - - def to_xml_model(self) -> Parameter: - """Convert to a Pydantic XML model.""" - return Parameter(id=self.parameter_id, value=self.value) - - -@dataclass -class UWSJobDescription: - """Brief job description used for the job list. 
- - This is a strict subset of the fields of `UWSJob`, but is kept separate - without an inheritance relationship to reflect how it's used in code. - """ - - job_id: str - """Unique identifier of the job.""" - - message_id: str | None - """Internal message identifier for the work queuing system.""" - - owner: str - """Identity of the owner of the job.""" - - phase: ExecutionPhase - """Execution phase of the job.""" - - run_id: str | None - """Optional opaque string provided by the client. - - The RunId is intended for the client to add a unique identifier to all - jobs that are part of a single operation from the perspective of the - client. This may aid in tracing issues through a complex system or - identifying which operation a job is part of. - """ - - creation_time: datetime - """When the job was created.""" - - def to_xml_model(self, base_url: str) -> ShortJobDescription: - """Convert to a Pydantic XML model. +class JobBase(BaseModel): + """Fields common to all variations of the job record.""" + + json_parameters: Annotated[ + dict[str, Any], + Field( + title="Job parameters", + description=( + "May be any JSON-serialized object. Stored opaquely and" + " returned as part of the job record." 
+ ), + examples=[ + { + "ids": ["data-id"], + "stencils": [ + { + "type": "circle", + "center": [1.1, 2.1], + "radius": 0.001, + } + ], + }, + ], + ), + ] + + run_id: Annotated[ + str | None, + Field( + title="Client-provided run ID", + description=( + "The run ID allows the client to add a unique identifier to" + " all jobs that are part of a single operation, which may aid" + " in tracing issues through a complex system or identifying" + " which operation a job is part of" + ), + examples=["daily-2024-10-29"], + ), + ] = None + + destruction_time: Annotated[ + UtcDatetime, + Field( + title="Destruction time", + description=( + "At this time, the job will be aborted if it is still" + " running, its results will be deleted, and it will either" + " change phase to ARCHIVED or all record of the job will be" + " discarded" + ), + examples=["2024-11-29T23:57:55+00:00"], + ), + ] + + execution_duration: Annotated[ + SecondsTimedelta | None, + Field( + title="Maximum execution duration", + description=( + "Allowed maximum execution duration. This is specified in" + " elapsed wall clock time (not CPU time). If null, the" + " execution time is unlimited. If the job runs for longer than" + " this time period, it will be aborted." 
+ ), + ), + PlainSerializer( + lambda t: int(t.total_seconds()) if t is not None else None, + return_type=int, + ), + ] = None + + +class JobCreate(JobBase): + """Information required to create a new UWS job (Wobbly format).""" + + +class SerializedJob(JobBase): + """A single UWS job (Wobbly format).""" + + id: Annotated[ + str, + Field( + title="Job ID", + description="Unique identifier of the job", + examples=["47183"], + ), + BeforeValidator(lambda v: str(v) if isinstance(v, int) else v), + ] + + service: Annotated[ + str, + Field( + title="Service", + description="Service responsible for this job", + examples=["vo-cutouts"], + ), + ] + + owner: Annotated[ + str, + Field( + title="Job owner", + description="Identity of the owner of the job", + examples=["someuser"], + ), + ] + + phase: Annotated[ + ExecutionPhase, + Field( + title="Execution phase", + description="Current execution phase of the job", + examples=[ + ExecutionPhase.PENDING, + ExecutionPhase.EXECUTING, + ExecutionPhase.COMPLETED, + ], + ), + ] + + message_id: Annotated[ + str | None, + Field( + title="Work queue message ID", + description=( + "Internal message identifier for the work queuing system." + " Only meaningful to the service that stored this ID." 
+ ), + examples=["e621a175-e3bf-4a61-98d7-483cb5fb9ec2"], + ), + ] = None + + creation_time: Annotated[ + UtcDatetime, + Field( + title="Creation time", + description="When the job was created", + examples=["2024-10-29T23:57:55+00:00"], + ), + ] + + start_time: Annotated[ + UtcDatetime | None, + Field( + title="Start time", + description="When the job started executing (if it has)", + examples=["2024-10-30T00:00:21+00:00", None], + ), + ] = None + + end_time: Annotated[ + UtcDatetime | None, + Field( + title="End time", + description="When the job stopped executing (if it has)", + examples=["2024-10-30T00:08:45+00:00", None], + ), + ] = None + + quote: Annotated[ + UtcDatetime | None, + Field( + title="Expected completion time", + description=( + "Expected completion time of the job if it were started now," + " or null to indicate that the expected duration is not known." + " If later than the destruction time, indicates that the job" + " is not possible due to resource constraints." + ), + ), + ] = None + + errors: Annotated[ + list[JobError], + Field( + title="Error", description="Error information if the job failed" + ), + ] = [] + + results: Annotated[ + list[JobResult], + Field( + title="Job results", + description="Results of the job, if it has finished", + ), + ] = [] + + def to_job_description(self, base_url: str) -> ShortJobDescription: + """Convert to the Pydantic XML model used for the summary of jobs. 
Parameters ---------- @@ -249,90 +395,67 @@ def to_xml_model(self, base_url: str) -> ShortJobDescription: run_id=self.run_id, creation_time=self.creation_time, owner_id=self.owner, - job_id=self.job_id, + job_id=self.id, type=None, - href=f"{base_url}/{self.job_id}", + href=f"{base_url}/{self.id}", ) -@dataclass -class UWSJob: - """Represents a single UWS job.""" - - job_id: str - """Unique identifier of the job.""" - - message_id: str | None - """Internal message identifier for the work queuing system.""" +class Job(SerializedJob, Generic[P]): + """A single UWS job with deserialized parameters.""" - owner: str - """Identity of the owner of the job.""" + parameters: Annotated[ + P, + Field( + title="Job parameters", + description=( + "Job parameters converted to their Pydantic model. Use" + " ``json_parameters`` for the serialized form sent over" + " the wire." + ), + exclude=True, + ), + ] - phase: ExecutionPhase - """Execution phase of the job.""" - - run_id: str | None - """Optional opaque string provided by the client. - - The RunId is intended for the client to add a unique identifier to all - jobs that are part of a single operation from the perspective of the - client. This may aid in tracing issues through a complex system or - identifying which operation a job is part of. - """ - - creation_time: datetime - """When the job was created.""" - - start_time: datetime | None - """When the job started executing (if it has started).""" - - end_time: datetime | None - """When the job stopped executing (if it has stopped).""" - - destruction_time: datetime - """Time at which the job should be destroyed. - - At this time, the job will be aborted if it is still running, its results - will be deleted, and all record of the job will be discarded. - - This field is optional in the UWS standard, but in this UWS implementation - all jobs will have a destruction time, so it is not marked as optional. 
- """ - - execution_duration: timedelta - """Allowed maximum execution duration in seconds. - - This is specified in elapsed wall clock time, or 0 for unlimited execution - time. If the job runs for longer than this time period, it will be - aborted. - """ - - quote: datetime | None - """Expected completion time of the job if it were started now. - - May be `None` to indicate that the expected duration of the job is not - known. Maybe later than the destruction time to indicate that the job is - not possible due to resource constraints. - """ - - error: UWSJobError | None - """Error information if the job failed.""" + @classmethod + def from_serialized_job( + cls, job: SerializedJob, parameters_type: type[P] + ) -> Self: + """Convert from a serialized job returned by Wobbly. - parameters: list[UWSJobParameter] - """The parameters of the job.""" + Parameters + ---------- + job + Serialized job from Wobbly. + parameters_type + Model to use for job parameters. - results: list[UWSJobResult] - """The results of the job.""" + Returns + ------- + Job + Job with the correct parameters model. - def to_xml_model( - self, parameters_type: type[P], job_summary_type: type[S] - ) -> S: + Raises + ------ + pydantic.ValidationError + Raised if the serialized parameters cannot be validated. + """ + job_dict = job.model_dump() + params = job_dict.get("json_parameters") + if params: + job_dict["parameters"] = parameters_type.model_validate(params) + return cls.model_validate(job_dict) + + def to_serialized_job(self) -> SerializedJob: + """Convert to a serialized job suitable for sending to Wobbly.""" + job = self.model_dump(mode="json") + return SerializedJob.model_validate(job) + + def to_xml_model(self, job_summary_type: type[S]) -> S: """Convert to a Pydantic XML model. Parameters ---------- - parameters_type - Model class used for the job parameters. job_summary_type XML model class for the job summary. 
@@ -344,9 +467,14 @@ def to_xml_model( results = None if self.results: results = Results(results=[r.to_xml_model() for r in self.results]) - parameters = parameters_type.from_job_parameters(self.parameters) + duration = None + if self.execution_duration: + duration = int(self.execution_duration.total_seconds()) + error_summary = None + if self.errors: + error_summary = self.errors[0].to_xml_model() return job_summary_type( - job_id=self.job_id, + job_id=self.id, run_id=self.run_id, owner_id=self.owner, phase=self.phase, @@ -354,10 +482,149 @@ def to_xml_model( creation_time=self.creation_time, start_time=self.start_time, end_time=self.end_time, - execution_duration=int(self.execution_duration.total_seconds()), + execution_duration=duration, destruction=self.destruction_time, - parameters=parameters.to_xml_model(), + parameters=self.parameters.to_xml_model(), results=results, - error_summary=self.error.to_xml_model() if self.error else None, + error_summary=error_summary, version=UWSVersion.V1_1, ) + + +class JobUpdateAborted(BaseModel): + """Input model when aborting a job.""" + + phase: Annotated[ + Literal[ExecutionPhase.ABORTED], + Field( + title="New phase", + description="New phase of job", + examples=[ExecutionPhase.ABORTED], + ), + ] + + +class JobUpdateCompleted(BaseModel): + """Input model when marking a job as complete.""" + + phase: Annotated[ + Literal[ExecutionPhase.COMPLETED], + Field( + title="New phase", + description="New phase of job", + examples=[ExecutionPhase.COMPLETED], + ), + ] + + results: Annotated[ + list[JobResult], + Field(title="Job results", description="All the results of the job"), + ] + + +class JobUpdateExecuting(BaseModel): + """Input model when marking a job as executing.""" + + phase: Annotated[ + Literal[ExecutionPhase.EXECUTING], + Field( + title="New phase", + description="New phase of job", + examples=[ExecutionPhase.EXECUTING], + ), + ] + + start_time: Annotated[ + UtcDatetime, + Field( + title="Start time", + 
description="When the job started executing", + examples=["2024-11-01T12:15:45+00:00"], + ), + ] + + +class JobUpdateError(BaseModel): + """Input model when marking a job as failed.""" + + phase: Annotated[ + Literal[ExecutionPhase.ERROR], + Field( + title="New phase", + description="New phase of job", + examples=[ExecutionPhase.ERROR], + ), + ] + + errors: Annotated[ + list[JobError], + Field( + title="Failure details", + description="Job failure error message and details", + min_length=1, + ), + ] + + +class JobUpdateQueued(BaseModel): + """Input model when marking a job as queued.""" + + phase: Annotated[ + Literal[ExecutionPhase.QUEUED], + Field( + title="New phase", + description="New phase of job", + examples=[ExecutionPhase.QUEUED], + ), + ] + + message_id: Annotated[ + str | None, + Field( + title="Queue message ID", + description="Corresponding message within a job queuing system", + examples=["4ce850a7-d877-4827-a3f6-f84534ec3fad"], + ), + ] + + +class JobUpdateMetadata(BaseModel): + """Input model when updating job metadata.""" + + phase: Annotated[ + None, + Field( + title="New phase", description="New phase of job", examples=[None] + ), + ] = None + + destruction_time: Annotated[ + UtcDatetime, + Field( + title="Destruction time", + description=( + "At this time, the job will be aborted if it is still" + " running, its results will be deleted, and it will either" + " change phase to ARCHIVED or all record of the job will be" + " discarded" + ), + examples=["2024-11-29T23:57:55+00:00"], + ), + ] + + execution_duration: Annotated[ + SecondsTimedelta | None, + Field( + title="Maximum execution duration", + description=( + "Allowed maximum execution duration. This is specified in" + " elapsed wall clock time (not CPU time). If null, the" + " execution time is unlimited. If the job runs for longer than" + " this time period, it will be aborted." 
+ ), + ), + PlainSerializer( + lambda t: int(t.total_seconds()) if t else None, + return_type=int | None, + ), + ] diff --git a/safir/src/safir/uws/_responses.py b/safir/src/safir/uws/_responses.py index afc00427..9dc64c36 100644 --- a/safir/src/safir/uws/_responses.py +++ b/safir/src/safir/uws/_responses.py @@ -8,8 +8,7 @@ from safir.datetime import isodatetime -from ._models import UWSJobError -from ._results import ResultStore +from ._models import JobError __all__ = ["UWSTemplates"] @@ -28,10 +27,7 @@ class UWSTemplates: This also includes VOSI-Availability since it was convenient to provide. """ - def __init__(self, result_store: ResultStore) -> None: - self._result_store = result_store - - def error(self, request: Request, error: UWSJobError) -> Response: + def error(self, request: Request, error: JobError) -> Response: """Return the error of a job as an XML response.""" return _templates.TemplateResponse( request, diff --git a/safir/src/safir/uws/_results.py b/safir/src/safir/uws/_results.py index 78a9e601..611e93b8 100644 --- a/safir/src/safir/uws/_results.py +++ b/safir/src/safir/uws/_results.py @@ -10,7 +10,7 @@ from safir.gcs import SignedURLService from ._config import UWSConfig -from ._models import UWSJobResult, UWSJobResultSigned +from ._models import JobResult, SignedJobResult __all__ = ["ResultStore"] @@ -31,7 +31,7 @@ def __init__(self, config: UWSConfig) -> None: lifetime=config.url_lifetime, ) - def sign_url(self, result: UWSJobResult) -> UWSJobResultSigned: + def sign_url(self, result: JobResult) -> SignedJobResult: """Convert a job result into a signed URL. Parameters @@ -41,7 +41,7 @@ def sign_url(self, result: UWSJobResult) -> UWSJobResultSigned: Returns ------- - UWSJobResultSigned + SignedJobResult Result with any GCS URL replaced with a signed URL. 
Notes @@ -63,8 +63,8 @@ def sign_url(self, result: UWSJobResult) -> UWSJobResultSigned: else: mime_type = result.mime_type signed_url = self._url_service.signed_url(result.url, mime_type) - return UWSJobResultSigned( - result_id=result.result_id, + return SignedJobResult( + id=result.id, url=signed_url, size=result.size, mime_type=result.mime_type, diff --git a/safir/src/safir/uws/_schema.py b/safir/src/safir/uws/_schema.py deleted file mode 100644 index 3919d378..00000000 --- a/safir/src/safir/uws/_schema.py +++ /dev/null @@ -1,101 +0,0 @@ -"""SQLAlchemy schema for the UWS database.""" - -from __future__ import annotations - -from datetime import datetime - -from sqlalchemy import ForeignKey, Index, String, Text -from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship -from vo_models.uws.types import ErrorType, ExecutionPhase - -from ._models import ErrorCode - -__all__ = [ - "Job", - "JobParameter", - "JobResult", - "UWSSchemaBase", -] - - -class UWSSchemaBase(DeclarativeBase): - """SQLAlchemy declarative base for the UWS database schema.""" - - -class JobResult(UWSSchemaBase): - """Table holding job results.""" - - __tablename__ = "job_result" - - id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) - job_id: Mapped[int] = mapped_column( - ForeignKey("job.id", ondelete="CASCADE") - ) - result_id: Mapped[str] = mapped_column(String(64)) - sequence: Mapped[int] - url: Mapped[str] = mapped_column(String(256)) - size: Mapped[int | None] - mime_type: Mapped[str | None] = mapped_column(String(64)) - - __table_args__ = ( - Index("by_sequence", "job_id", "sequence", unique=True), - Index("by_result_id", "job_id", "result_id", unique=True), - ) - - -class JobParameter(UWSSchemaBase): - """Table holding parameters to UWS jobs.""" - - __tablename__ = "job_parameter" - - id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) - job_id: Mapped[int] = mapped_column( - ForeignKey("job.id", ondelete="CASCADE") - ) - 
parameter: Mapped[str] = mapped_column(String(64)) - value: Mapped[str] = mapped_column(Text) - is_post: Mapped[bool] = mapped_column(default=False) - - __table_args__ = (Index("by_parameter", "job_id", "parameter"),) - - -class Job(UWSSchemaBase): - """Table holding UWS jobs. - - Notes - ----- - The details of how the relationships are defined are chosen to allow this - schema to be used with async SQLAlchemy. Review the SQLAlchemy asyncio - documentation carefully before making changes. There are a lot of - surprises and sharp edges. - """ - - __tablename__ = "job" - - id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) - message_id: Mapped[str | None] = mapped_column(String(64)) - owner: Mapped[str] = mapped_column(String(64)) - phase: Mapped[ExecutionPhase] - run_id: Mapped[str | None] = mapped_column(String(64)) - creation_time: Mapped[datetime] - start_time: Mapped[datetime | None] - end_time: Mapped[datetime | None] - destruction_time: Mapped[datetime] - execution_duration: Mapped[int] - quote: Mapped[datetime | None] - error_type: Mapped[ErrorType | None] - error_code: Mapped[ErrorCode | None] - error_message: Mapped[str | None] = mapped_column(Text) - error_detail: Mapped[str | None] = mapped_column(Text) - - parameters: Mapped[list[JobParameter]] = relationship( - cascade="delete", lazy="selectin", uselist=True - ) - results: Mapped[list[JobResult]] = relationship( - cascade="delete", lazy="selectin", uselist=True - ) - - __table_args__ = ( - Index("by_owner_phase", "owner", "phase", "creation_time"), - Index("by_owner_time", "owner", "creation_time"), - ) diff --git a/safir/src/safir/uws/_service.py b/safir/src/safir/uws/_service.py index 96dd86d5..1a0b2076 100644 --- a/safir/src/safir/uws/_service.py +++ b/safir/src/safir/uws/_service.py @@ -19,18 +19,17 @@ from ._constants import JOB_STOP_TIMEOUT from ._exceptions import ( InvalidPhaseError, - PermissionDeniedError, SyncJobFailedError, SyncJobNoResultsError, SyncJobTimeoutError, ) 
from ._models import ( ACTIVE_PHASES, + Job, + JobError, + JobResult, + JobUpdateMetadata, ParametersModel, - UWSJob, - UWSJobError, - UWSJobParameter, - UWSJobResult, ) from ._results import ResultStore from ._storage import JobStore @@ -71,7 +70,7 @@ def __init__( self._storage = storage self._logger = logger - async def abort(self, user: str, job_id: str) -> None: + async def abort(self, token: str, job_id: str) -> None: """Abort a queued or running job. If the job is already in a completed state, this operation does @@ -79,22 +78,20 @@ async def abort(self, user: str, job_id: str) -> None: Parameters ---------- - user - User on behalf of whom this operation is performed. + token + Delegated token for user. job_id Identifier of the job. Raises ------ - PermissionDeniedError - If the job ID doesn't exist or is for a user other than the - provided user. + UnknownJobError + Raised if the job was not found. + WobblyError + Raised if the Wobbly request fails or returns a failure status. """ - job = await self._storage.get(job_id) - if job.owner != user: - raise PermissionDeniedError(f"Access to job {job_id} denied") - params_model = self._validate_parameters(job.parameters) - logger = self._build_logger_for_job(job, params_model) + job = await self._storage.get(token, job_id) + logger = self._build_logger_for_job(job) if job.phase not in ACTIVE_PHASES: logger.info(f"Cannot stop job in phase {job.phase.value}") return @@ -102,29 +99,24 @@ async def abort(self, user: str, job_id: str) -> None: timeout = JOB_STOP_TIMEOUT.total_seconds() logger.info("Aborting queued job", arq_job_id=job.message_id) await self._arq.abort_job(job.message_id, timeout=timeout) - await self._storage.mark_aborted(job_id) + await self._storage.mark_aborted(token, job_id) logger.info("Aborted job") async def availability(self) -> Availability: - """Check whether the service is up. + """Check the availability of underlying services. - Used for ``/availability`` endpoints. 
Currently this only checks the - database. - - Returns - ------- - Availability - Service availability information. + Currently, this does nothing. Eventually, it may do a health check of + Wobbly. """ - return await self._storage.availability() + return Availability(available=True) async def create( self, - user: str, - params: list[UWSJobParameter], + token: str, + parameters: ParametersModel, *, run_id: str | None = None, - ) -> UWSJob: + ) -> Job: """Create a pending job. This does not start execution of the job. That must be done separately @@ -132,84 +124,72 @@ async def create( Parameters ---------- - user - User on behalf this operation is performed. + token + Delegated token for user. + parameters + The input parameters to the job. run_id A client-supplied opaque identifier to record with the job. - params - The input parameters to the job. Returns ------- - JobSummary + Job Information about the newly-created job. + + Raises + ------ + WobblyError + Raised if the Wobbly request fails or returns a failure status. """ - params_model = self._validate_parameters(params) - job = await self._storage.add( - owner=user, + job = await self._storage.create( + token, run_id=run_id, - params=params, + parameters=parameters, execution_duration=self._config.execution_duration, lifetime=self._config.lifetime, ) - logger = self._build_logger_for_job(job, params_model) + logger = self._build_logger_for_job(job) logger.info("Created job") return job - async def delete(self, user: str, job_id: str) -> None: + async def delete(self, token: str, job_id: str) -> None: """Delete a job. If the job is in an active phase, cancel it before deleting it. Parameters ---------- - user - Owner of job. + token + Delegated token for user. job_id Identifier of job. + + Raises + ------ + UnknownJobError + Raised if the job was not found. + WobblyError + Raised if the Wobbly request fails or returns a failure status. 
""" - job = await self._storage.get(job_id) - if job.owner != user: - raise PermissionDeniedError(f"Access to job {job_id} denied") - logger = self._logger.bind(user=user, job_id=job_id) + job = await self._storage.get(token, job_id) + logger = self._build_logger_for_job(job) if job.phase in ACTIVE_PHASES and job.message_id: try: await self._arq.abort_job(job.message_id) except Exception as e: logger.warning("Unable to abort job", error=str(e)) - await self._storage.delete(job_id) + await self._storage.delete(token, job_id) logger.info("Deleted job") - async def delete_expired(self) -> None: - """Delete all expired jobs. - - A job is expired if it has passed its destruction time. If the job is - in an active phase, cancel it before deleting it. - """ - jobs = await self._storage.list_expired() - if jobs: - self._logger.info(f"Deleting {len(jobs)} expired jobs") - for job in jobs: - if job.phase in ACTIVE_PHASES and job.message_id: - try: - await self._arq.abort_job(job.message_id) - except Exception as e: - self._logger.warning( - "Unable to abort expired job", error=str(e) - ) - await self._storage.delete(job.job_id) - self._logger.info("Deleted expired job") - self._logger.info(f"Finished deleting {len(jobs)} expired jobs") - async def get( self, - user: str, + token: str, job_id: str, *, wait_seconds: int | None = None, wait_phase: ExecutionPhase | None = None, wait_for_completion: bool = False, - ) -> UWSJob: + ) -> Job: """Retrieve a job. This also supports long-polling, to implement UWS 1.1 blocking @@ -218,8 +198,8 @@ async def get( Parameters ---------- - user - User on behalf this operation is performed. + token + Delegated token for user. job_id Identifier of the job. wait_seconds @@ -238,14 +218,15 @@ async def get( Returns ------- - UWSJob + Job Corresponding job. Raises ------ - PermissionDeniedError - If the job ID doesn't exist or is for a user other than the - provided user. + UnknownJobError + Raised if the job was not found. 
+ WobblyError + Raised if the Wobbly request fails or returns a failure status. Notes ----- @@ -254,9 +235,7 @@ async def get( 0.1s delay and increasing by 1.5x). This may need to be reconsidered if it becomes a performance bottleneck. """ - job = await self._storage.get(job_id) - if job.owner != user: - raise PermissionDeniedError(f"Access to job {job_id} denied") + job = await self._storage.get(token, job_id) # If waiting for a status change was requested and is meaningful, do # so, capping the wait time at the configured maximum timeout. @@ -270,39 +249,40 @@ async def get( until_not = ACTIVE_PHASES else: until_not = {wait_phase} if wait_phase else {job.phase} - job = await self._wait_for_job(job, until_not, wait) + job = await self._wait_for_job(token, job, until_not, timeout=wait) return job - async def get_error(self, user: str, job_id: str) -> UWSJobError | None: - """Get the error for a job, if any. + async def get_error( + self, token: str, job_id: str + ) -> list[JobError] | None: + """Get the errors for a job, if any. Parameters ---------- - user - User on behalf this operation is performed. + token + Delegated token for user. job_id Identifier of the job. Returns ------- - UWSJobError or None + list of JobError or None Error information for the job, or `None` if the job didn't fail. Raises ------ - PermissionDeniedError - If the job ID doesn't exist or is for a user other than the - provided user. + UnknownJobError + Raised if the job was not found. + WobblyError + Raised if the Wobbly request fails or returns a failure status. 
""" - job = await self._storage.get(job_id) - if job.owner != user: - raise PermissionDeniedError(f"Access to job {job_id} denied") - return job.error + job = await self._storage.get(token, job_id) + return job.errors async def get_summary( self, - user: str, + token: str, job_id: str, *, signer: ResultStore | None = None, @@ -316,8 +296,8 @@ async def get_summary( Parameters ---------- - user - User on behalf this operation is performed. + token + Delegated token for user. job_id Identifier of the job. signer @@ -341,9 +321,10 @@ async def get_summary( Raises ------ - PermissionDeniedError - If the job ID doesn't exist or is for a user other than the - provided user. + UnknownJobError + Raised if the job was not found. + WobblyError + Raised if the Wobbly request fails or returns a failure status. Notes ----- @@ -353,17 +334,15 @@ async def get_summary( if it becomes a performance bottleneck. """ job = await self.get( - user, job_id, wait_seconds=wait_seconds, wait_phase=wait_phase + token, job_id, wait_seconds=wait_seconds, wait_phase=wait_phase ) if signer: job.results = [signer.sign_url(r) for r in job.results] - return job.to_xml_model( - self._config.parameters_type, self._config.job_summary_type - ) + return job.to_xml_model(self._config.job_summary_type) async def list_jobs( self, - user: str, + token: str, base_url: str, *, phases: list[ExecutionPhase] | None = None, @@ -374,8 +353,8 @@ async def list_jobs( Parameters ---------- - user - Name of the user whose jobs to load. + token + Delegated token for user. base_url Base URL used to form URLs to the specific jobs. phases @@ -390,36 +369,43 @@ async def list_jobs( ------- Jobs Collection of short job descriptions. + + Raises + ------ + WobblyError + Raised if the Wobbly request fails or returns a failure status. 
""" jobs = await self._storage.list_jobs( - user, phases=phases, after=after, count=count + token, phases=phases, after=after, count=count ) - return Jobs(jobref=[j.to_xml_model(base_url) for j in jobs]) + return Jobs(jobref=[j.to_job_description(base_url) for j in jobs]) async def run_sync( self, + token: str, user: str, - params: list[UWSJobParameter], + parameters: ParametersModel, *, - token: str, runid: str | None, - ) -> UWSJobResult: + ) -> JobResult: """Create a job for a sync request and return the first result. Parameters ---------- + token + Delegated token for user. + user + User on behalf of whom this operation is performed. params Job parameters. user Username of user running the job. - token - Delegated Gafaelfawr token to pass to the backend worker. runid User-supplied RunID, if any. Returns ------- - result + JobResult First result of the successfully-executed job. Raises @@ -430,17 +416,18 @@ async def run_sync( Raised if the job returned no results. SyncJobTimeoutError Raised if the job execution timed out. + WobblyError + Raised if the Wobbly request fails or returns a failure status. """ - job = await self.create(user, params, run_id=runid) - params_model = self._validate_parameters(params) - logger = self._build_logger_for_job(job, params_model) + job = await self.create(token, parameters, run_id=runid) + logger = self._build_logger_for_job(job) # Start the job and wait for it to complete. 
- metadata = await self.start(user, job.job_id, token) + metadata = await self.start(token, user, job.id) logger = logger.bind(arq_job_id=metadata.id) job = await self.get( - user, - job.job_id, + token, + job.id, wait_seconds=int(self._config.sync_timeout.total_seconds()), wait_for_completion=True, ) @@ -449,9 +436,11 @@ async def run_sync( if job.phase not in (ExecutionPhase.COMPLETED, ExecutionPhase.ERROR): logger.warning("Job timed out", timeout=self._config.sync_timeout) raise SyncJobTimeoutError(self._config.sync_timeout) - if job.error: - logger.warning("Job failed", error=job.error.to_dict()) - raise SyncJobFailedError(job.error) + if job.errors: + # Only one error is supported for right now. + error = job.errors[0] + logger.warning("Job failed", error=error.model_dump(mode="json")) + raise SyncJobFailedError(error) if not job.results: logger.warning("Job returned no results") raise SyncJobNoResultsError @@ -459,18 +448,18 @@ async def run_sync( # Return the first result. return job.results[0] - async def start(self, user: str, job_id: str, token: str) -> JobMetadata: + async def start(self, token: str, user: str, job_id: str) -> JobMetadata: """Start execution of a job. Parameters ---------- + token + Gafaelfawr token used to authenticate to services used by the + backend on the user's behalf. user User on behalf of whom this operation is performed. job_id Identifier of the job to start. - token - Gafaelfawr token used to authenticate to services used by the - backend on the user's behalf. Returns ------- @@ -479,39 +468,37 @@ async def start(self, user: str, job_id: str, token: str) -> JobMetadata: Raises ------ - safir.uws._exceptions.PermissionDeniedError - If the job ID doesn't exist or is for a user other than the - provided user. + UnknownJobError + Raised if the job was not found. + WobblyError + Raised if the Wobbly request fails or returns a failure status. 
""" - job = await self._storage.get(job_id) - if job.owner != user: - raise PermissionDeniedError(f"Access to job {job_id} denied") + job = await self._storage.get(token, job_id) if job.phase not in (ExecutionPhase.PENDING, ExecutionPhase.HELD): raise InvalidPhaseError(f"Cannot start job in phase {job.phase}") - params_model = self._validate_parameters(job.parameters) - logger = self._build_logger_for_job(job, params_model) + logger = self._build_logger_for_job(job) info = WorkerJobInfo( - job_id=job.job_id, + job_id=job.id, user=user, token=token, - timeout=job.execution_duration, + timeout=job.execution_duration or self._config.lifetime, run_id=job.run_id, ) - params = params_model.to_worker_parameters().model_dump(mode="json") + params = job.parameters.to_worker_parameters().model_dump(mode="json") metadata = await self._arq.enqueue(self._config.worker, params, info) - await self._storage.mark_queued(job_id, metadata) + await self._storage.mark_queued(token, job_id, metadata) logger.info("Started job", arq_job_id=metadata.id) return metadata async def update_destruction( - self, user: str, job_id: str, destruction: datetime + self, token: str, job_id: str, destruction: datetime ) -> datetime | None: """Update the destruction time of a job. Parameters ---------- - user - User on behalf of whom this operation is performed + token + Delegated token for user. job_id Identifier of the job to update. destruction @@ -527,13 +514,13 @@ async def update_destruction( Raises ------ - safir.uws._exceptions.PermissionDeniedError - If the job ID doesn't exist or is for a user other than the - provided user. + UnknownJobError + Raised if the job was not found. + WobblyError + Raised if the Wobbly request fails or returns a failure status. 
""" - job = await self._storage.get(job_id) - if job.owner != user: - raise PermissionDeniedError(f"Access to job {job_id} denied") + job = await self._storage.get(token, job_id) + logger = self._build_logger_for_job(job) # Validate the new value. if validator := self._config.validate_destruction: @@ -544,24 +531,26 @@ async def update_destruction( # Update the destruction time if needed. if destruction == job.destruction_time: return None - await self._storage.update_destruction(job_id, destruction) - self._logger.info( + metadata = JobUpdateMetadata( + destruction_time=destruction, + execution_duration=job.execution_duration, + ) + await self._storage.update_metadata(token, job_id, metadata) + logger.info( "Changed job destruction time", - user=user, - job_id=job_id, destruction=isodatetime(destruction), ) return destruction async def update_execution_duration( - self, user: str, job_id: str, duration: timedelta + self, token: str, job_id: str, duration: timedelta | None ) -> timedelta | None: """Update the execution duration time of a job. Parameters ---------- - user - User on behalf of whom this operation is performed + token + Delegated token for user. job_id Identifier of the job to update. duration @@ -577,93 +566,69 @@ async def update_execution_duration( Raises ------ - safir.uws._exceptions.PermissionDeniedError - If the job ID doesn't exist or is for a user other than the - provided user. + UnknownJobError + Raised if the job was not found. + WobblyError + Raised if the Wobbly request fails or returns a failure status. """ - job = await self._storage.get(job_id) - if job.owner != user: - raise PermissionDeniedError(f"Access to job {job_id} denied") + job = await self._storage.get(token, job_id) + logger = self._build_logger_for_job(job) # Validate the new value. 
if validator := self._config.validate_execution_duration: duration = validator(duration, job) - duration = min(duration, self._config.execution_duration) + if duration: + duration = min(duration, self._config.execution_duration) # Update the duration in the job. if duration == job.execution_duration: return None - await self._storage.update_execution_duration(job_id, duration) - if duration.total_seconds() > 0: + update = JobUpdateMetadata( + destruction_time=job.destruction_time, execution_duration=duration + ) + await self._storage.update_metadata(token, job_id, update) + if duration: duration_str = f"{duration.total_seconds()}s" else: duration_str = "unlimited" - self._logger.info( - "Changed job execution duration", - user=user, - job_id=job_id, - duration=duration_str, - ) + logger.info("Changed job execution duration", duration=duration_str) return duration - def _build_logger_for_job( - self, job: UWSJob, params: ParametersModel | None - ) -> BoundLogger: + def _build_logger_for_job(self, job: Job) -> BoundLogger: """Construct a logger with bound information for a job. Parameters ---------- job Job for which to report messages. - params - Job parameters in model form, if available. Returns ------- BoundLogger Logger with more bound metadata. """ - logger = self._logger.bind(user=job.owner, job_id=job.job_id) + logger = self._logger.bind(user=job.owner, job_id=job.id) if job.run_id: logger = logger.bind(run_id=job.run_id) - if params: - logger = logger.bind(parameters=params.model_dump(mode="json")) + if job.parameters: + parameters = job.parameters.model_dump(mode="json") + logger = logger.bind(parameters=parameters) return logger - def _validate_parameters( - self, params: list[UWSJobParameter] - ) -> ParametersModel: - """Convert UWS job parameters to the parameter model for the service. 
- - As a side effect, this also verifies that the parameters are valid, - so it is used when creating a job or modifying its parameters to - ensure that the new parameters are valid. - - Parameters - ---------- - params - Job parameters in the UWS job parameter format. - - Returns - ------- - pydantic.BaseModel - Paramters in the model provided by the service, which will be - some subclass of `pydantic.BaseModel`. - - Raises - ------ - safir.uws.UWSError - Raised if there is some problem with the job parameters. - """ - return self._config.parameters_type.from_job_parameters(params) - async def _wait_for_job( - self, job: UWSJob, until_not: set[ExecutionPhase], timeout: timedelta - ) -> UWSJob: + self, + token: str, + job: Job, + until_not: set[ExecutionPhase], + *, + timeout: timedelta, + ) -> Job: """Wait for the completion of a job. Parameters ---------- + token + Delegated token for user. job Job to wait for. until_not @@ -686,9 +651,9 @@ async def _wait_for_job( async with asyncio.timeout(timeout.total_seconds()): while job.phase in until_not: await asyncio.sleep(delay) - job = await self._storage.get(job.job_id) + job = await self._storage.get(token, job.id) delay = min(delay * 1.5, max_delay) # If we timed out, we may have done so in the middle of a delay. Try # one last request. 
- return await self._storage.get(job.job_id) + return await self._storage.get(token, job.id) diff --git a/safir/src/safir/uws/_storage.py b/safir/src/safir/uws/_storage.py index 035dbb20..7571a080 100644 --- a/safir/src/safir/uws/_storage.py +++ b/safir/src/safir/uws/_storage.py @@ -2,89 +2,38 @@ from __future__ import annotations +from collections.abc import Iterable from datetime import datetime, timedelta -from typing import ParamSpec, TypeVar +from typing import Any -from sqlalchemy import delete -from sqlalchemy.exc import OperationalError -from sqlalchemy.ext.asyncio import async_scoped_session -from sqlalchemy.future import select +from httpx import AsyncClient, HTTPError, Response +from pydantic import BaseModel from vo_models.uws.types import ErrorType, ExecutionPhase -from vo_models.vosi.availability import Availability -from safir.arq import JobMetadata, JobResult -from safir.database import ( - datetime_from_db, - datetime_to_db, - retry_async_transaction, -) -from safir.datetime import current_datetime +from safir.arq import JobMetadata +from safir.arq import JobResult as ArqJobResult +from safir.datetime import current_datetime, isodatetime -from ._exceptions import TaskError, UnknownJobError +from ._config import UWSConfig +from ._exceptions import TaskError, UnknownJobError, WobblyError from ._models import ( - ErrorCode, - UWSJob, - UWSJobDescription, - UWSJobError, - UWSJobParameter, - UWSJobResult, + Job, + JobCreate, + JobError, + JobResult, + JobUpdateAborted, + JobUpdateCompleted, + JobUpdateError, + JobUpdateExecuting, + JobUpdateMetadata, + JobUpdateQueued, + ParametersModel, + SerializedJob, ) -from ._schema import Job as SQLJob -from ._schema import JobParameter as SQLJobParameter -from ._schema import JobResult as SQLJobResult - -T = TypeVar("T") -P = ParamSpec("P") __all__ = ["JobStore"] -def _convert_job(job: SQLJob) -> UWSJob: - """Convert the SQL representation of a job to its dataclass. 
- - The internal representation of a job uses a dataclass that is kept - intentionally separate from the database schema so that the conversion can - be done explicitly and the rest of the code kept separate from SQLAlchemy - database models. This internal helper function converts from the database - representation to the internal representation. - """ - error = None - if job.error_code and job.error_type and job.error_message: - error = UWSJobError( - error_type=job.error_type, - error_code=job.error_code, - message=job.error_message, - detail=job.error_detail, - ) - return UWSJob( - job_id=str(job.id), - message_id=job.message_id, - owner=job.owner, - phase=job.phase, - run_id=job.run_id, - creation_time=datetime_from_db(job.creation_time), - start_time=datetime_from_db(job.start_time), - end_time=datetime_from_db(job.end_time), - destruction_time=datetime_from_db(job.destruction_time), - execution_duration=timedelta(seconds=job.execution_duration), - quote=job.quote, - parameters=[ - UWSJobParameter(parameter_id=p.parameter, value=p.value) - for p in sorted(job.parameters, key=lambda p: p.id) - ], - results=[ - UWSJobResult( - result_id=r.result_id, - url=r.url, - size=r.size, - mime_type=r.mime_type, - ) - for r in sorted(job.results, key=lambda r: r.sequence) - ], - error=error, - ) - - class JobStore: """Stores and manipulates jobs in the database. @@ -94,33 +43,37 @@ class JobStore: Parameters ---------- - session - The underlying database session. + config + UWS configuration. + http_client + HTTP client to use to talk to Wobbly. 
""" - def __init__(self, session: async_scoped_session) -> None: - self._session = session + def __init__(self, config: UWSConfig, http_client: AsyncClient) -> None: + self._config = config + self._client = http_client + self._base_url = str(config.wobbly_url).rstrip("/") - async def add( + async def create( self, + token: str, *, - owner: str, run_id: str | None = None, - params: list[UWSJobParameter], + parameters: ParametersModel, execution_duration: timedelta, lifetime: timedelta, - ) -> UWSJob: + ) -> Job: """Create a record of a new job. The job will be created in pending status. Parameters ---------- - owner - The username of the owner of the job. + token + Token for an individual user. run_id A client-supplied opaque identifier to record with the job. - params + parameters The input parameters to the job. execution_duration The maximum length of time for which a job is allowed to run in @@ -131,94 +84,78 @@ async def add( Returns ------- - safir.uws._models.Job - The internal representation of the newly-created job. + Job + Newly-created job. + + Raises + ------ + WobblyError + Raised if the Wobbly request fails or returns a failure status. 
""" - now = current_datetime() - destruction_time = now + lifetime - sql_params = [ - SQLJobParameter(parameter=p.parameter_id, value=p.value) - for p in params - ] - job = SQLJob( - owner=owner, - phase=ExecutionPhase.PENDING, + job_create = JobCreate( + json_parameters=parameters.model_dump(mode="json"), run_id=run_id, - creation_time=datetime_to_db(now), - destruction_time=datetime_to_db(destruction_time), - execution_duration=int(execution_duration.total_seconds()), - parameters=sql_params, - results=[], + destruction_time=current_datetime() + lifetime, + execution_duration=execution_duration, ) - async with self._session.begin(): - self._session.add_all([job, *sql_params]) - await self._session.flush() - return _convert_job(job) + r = await self._request("POST", token, body=job_create) + job = SerializedJob.model_validate(r.json()) + return Job.from_serialized_job(job, self._config.parameters_type) - async def availability(self) -> Availability: - """Check that the database is up.""" - try: - async with self._session.begin(): - await self._session.execute(select(SQLJob.id).limit(1)) - return Availability(available=True) - except OperationalError: - note = "cannot query UWS job database" - return Availability(available=False, note=[note]) - except Exception as e: - note = f"{type(e).__name__}: {e!s}" - return Availability(available=False, note=[note]) - - async def delete(self, job_id: str) -> None: - """Delete a job by ID.""" - stmt = delete(SQLJob).where(SQLJob.id == int(job_id)) - async with self._session.begin(): - await self._session.execute(stmt) - - async def get(self, job_id: str) -> UWSJob: - """Retrieve a job by ID.""" - async with self._session.begin(): - job = await self._get_job(job_id) - return _convert_job(job) - - async def list_expired(self) -> list[UWSJobDescription]: - """Delete all jobs that have passed their destruction time.""" - now = datetime_to_db(current_datetime()) - stmt = select( - SQLJob.id, - SQLJob.message_id, - SQLJob.owner, - 
SQLJob.phase, - SQLJob.run_id, - SQLJob.creation_time, - ).where(SQLJob.destruction_time <= now) - async with self._session.begin(): - jobs = await self._session.execute(stmt) - return [ - UWSJobDescription( - job_id=str(j.id), - message_id=j.message_id, - owner=j.owner, - phase=j.phase, - run_id=j.run_id, - creation_time=datetime_from_db(j.creation_time), - ) - for j in jobs.all() - ] + async def delete(self, token: str, job_id: str) -> None: + """Delete a job by ID. + + Parameters + ---------- + token + Token for an individual user. + job_id + Job ID to delete. + + Raises + ------ + UnknownJobError + Raised if the job was not found. + WobblyError + Raised if the Wobbly request fails or returns a failure status. + """ + await self._request("DELETE", token, job_id) + + async def get(self, token: str, job_id: str) -> Job: + """Retrieve a job by ID. + + Parameters + ---------- + token + Token for an individual user. + job_id + Job ID to retrieve. + + Raises + ------ + UnknownJobError + Raised if the job was not found. + WobblyError + Raised if the Wobbly request fails or returns a failure status. + """ + r = await self._request("GET", token, job_id) + job = SerializedJob.model_validate(r.json()) + return Job.from_serialized_job(job, self._config.parameters_type) async def list_jobs( self, - user: str, + token: str, *, - phases: list[ExecutionPhase] | None = None, + phases: Iterable[ExecutionPhase] | None = None, after: datetime | None = None, count: int | None = None, - ) -> list[UWSJobDescription]: + ) -> list[SerializedJob]: """List the jobs for a particular user. Parameters ---------- - user - Name of the user whose jobs to load. + token + Token for an individual user. phases Limit the result to jobs in this list of possible execution phases. @@ -229,140 +166,139 @@ async def list_jobs( Returns ------- - list of safir.uws._models.JobDescription + list of SerializedJob List of job descriptions matching the search criteria. 
+ + Raises + ------ + WobblyError + Raised if the Wobbly request fails or returns a failure status. """ - stmt = select( - SQLJob.id, - SQLJob.message_id, - SQLJob.owner, - SQLJob.phase, - SQLJob.run_id, - SQLJob.creation_time, - ).where(SQLJob.owner == user) + query: list[tuple[str, str]] = [] if phases: - stmt = stmt.where(SQLJob.phase.in_(phases)) + query.extend(("phase", p.value) for p in phases) if after: - stmt = stmt.where(SQLJob.creation_time > datetime_to_db(after)) - stmt = stmt.order_by(SQLJob.creation_time.desc()) + query.append(("since", isodatetime(after))) if count: - stmt = stmt.limit(count) - async with self._session.begin(): - jobs = await self._session.execute(stmt) - return [ - UWSJobDescription( - job_id=str(j.id), - message_id=j.message_id, - owner=j.owner, - phase=j.phase, - run_id=j.run_id, - creation_time=datetime_from_db(j.creation_time), - ) - for j in jobs.all() - ] - - @retry_async_transaction - async def mark_aborted(self, job_id: str) -> None: + query.append(("limit", str(count))) + r = await self._request("GET", token, query=query) + return [SerializedJob.model_validate(j) for j in r.json()] + + async def mark_aborted(self, token: str, job_id: str) -> None: """Mark a job as aborted. Parameters ---------- + token + Token for an individual user. job_id Identifier of the job. + + Raises + ------ + UnknownJobError + Raised if the job was not found. + WobblyError + Raised if the Wobbly request fails or returns a failure status. 
""" - async with self._session.begin(): - job = await self._get_job(job_id) - job.phase = ExecutionPhase.ABORTED - if job.start_time: - job.end_time = datetime_to_db(current_datetime()) - - @retry_async_transaction - async def mark_completed(self, job_id: str, job_result: JobResult) -> None: + update = JobUpdateAborted(phase=ExecutionPhase.ABORTED) + await self._request("PATCH", token, job_id, body=update) + + async def mark_completed( + self, token: str, job_id: str, job_result: ArqJobResult + ) -> None: """Mark a job as completed. Parameters ---------- + token + Token for an individual user. job_id Identifier of the job. job_result Result of the job. + + Raises + ------ + UnknownJobError + Raised if the job was not found. + WobblyError + Raised if the Wobbly request fails or returns a failure status. """ - end_time = job_result.finish_time.replace(microsecond=0) - results = job_result.result - if isinstance(results, Exception): - await self.mark_failed(job_id, results, end_time=end_time) + if isinstance(job_result.result, Exception): + await self.mark_failed(token, job_id, job_result.result) return + results = [JobResult.from_worker_result(r) for r in job_result.result] + update = JobUpdateCompleted( + phase=ExecutionPhase.COMPLETED, results=results + ) + await self._request("PATCH", token, job_id, body=update) - async with self._session.begin(): - job = await self._get_job(job_id) - job.end_time = datetime_to_db(end_time) - if job.phase != ExecutionPhase.ABORTED: - job.phase = ExecutionPhase.COMPLETED - for sequence, result in enumerate(results, start=1): - sql_result = SQLJobResult( - job_id=job.id, - result_id=result.result_id, - sequence=sequence, - url=result.url, - size=result.size, - mime_type=result.mime_type, - ) - self._session.add(sql_result) - - @retry_async_transaction async def mark_failed( - self, job_id: str, exc: Exception, *, end_time: datetime | None = None + self, token: str, job_id: str, exc: Exception ) -> None: """Mark a job as failed with 
an error. + Currently, only one error is supported, even though Wobbly supports + associating multiple errors with a job. + Parameters ---------- + token + Token for an individual user. job_id Identifier of the job. exc Exception of failed job. - end_time - When the job failed, if known. + + Raises + ------ + UnknownJobError + Raised if the job was not found. + WobblyError + Raised if the Wobbly request fails or returns a failure status. """ if isinstance(exc, TaskError): error = exc.to_job_error() else: - error = UWSJobError( - error_type=ErrorType.FATAL, - error_code=ErrorCode.ERROR, + error = JobError( + type=ErrorType.FATAL, + code="Error", message="Unknown error executing task", detail=f"{type(exc).__name__}: {exc!s}", ) - async with self._session.begin(): - job = await self._get_job(job_id) - job.end_time = datetime_to_db(end_time or current_datetime()) - if job.phase != ExecutionPhase.ABORTED: - job.phase = ExecutionPhase.ERROR - job.error_type = error.error_type - job.error_code = error.error_code - job.error_message = error.message - job.error_detail = error.detail - - @retry_async_transaction - async def mark_executing(self, job_id: str, start_time: datetime) -> None: + update = JobUpdateError(phase=ExecutionPhase.ERROR, errors=[error]) + await self._request("PATCH", token, job_id, body=update) + + async def mark_executing( + self, token: str, job_id: str, start_time: datetime + ) -> None: """Mark a job as executing. Parameters ---------- + token + Token for an individual user. job_id Identifier of the job. start_time Time at which the job started executing. + + Raises + ------ + UnknownJobError + Raised if the job was not found. + WobblyError + Raised if the Wobbly request fails or returns a failure status. 
""" - start_time = start_time.replace(microsecond=0) - async with self._session.begin(): - job = await self._get_job(job_id) - if job.phase in (ExecutionPhase.PENDING, ExecutionPhase.QUEUED): - job.phase = ExecutionPhase.EXECUTING - job.start_time = datetime_to_db(start_time) - - @retry_async_transaction - async def mark_queued(self, job_id: str, metadata: JobMetadata) -> None: + update = JobUpdateExecuting( + phase=ExecutionPhase.EXECUTING, start_time=start_time + ) + await self._request("PATCH", token, job_id, body=update) + + async def mark_queued( + self, token: str, job_id: str, metadata: JobMetadata + ) -> None: """Mark a job as queued for processing. This is called by the web frontend after queuing the work. However, @@ -371,54 +307,104 @@ async def mark_queued(self, job_id: str, metadata: JobMetadata) -> None: Parameters ---------- + token + Token for an individual user. job_id Identifier of the job. metadata Metadata about the underlying arq job. + + Raises + ------ + UnknownJobError + Raised if the job was not found. + WobblyError + Raised if the Wobbly request fails or returns a failure status. """ - async with self._session.begin(): - job = await self._get_job(job_id) - job.message_id = metadata.id - if job.phase in (ExecutionPhase.PENDING, ExecutionPhase.HELD): - job.phase = ExecutionPhase.QUEUED - - async def update_destruction( - self, job_id: str, destruction: datetime + update = JobUpdateQueued( + phase=ExecutionPhase.QUEUED, message_id=metadata.id + ) + await self._request("PATCH", token, job_id, body=update) + + async def update_metadata( + self, + token: str, + job_id: str, + metadata: JobUpdateMetadata, ) -> None: - """Update the destruction time of a job. + """Update the destruction time or execution duration of a job. Parameters ---------- + token + Token for an individual user. job_id Identifier of the job. - destruction - New destruction time. + metadata + New job metadata. 
+ + Raises + ------ + UnknownJobError + Raised if the job was not found. + WobblyError + Raised if the Wobbly request fails or returns a failure status. """ - destruction = destruction.replace(microsecond=0) - async with self._session.begin(): - job = await self._get_job(job_id) - job.destruction_time = datetime_to_db(destruction) + await self._request("PATCH", token, job_id, body=metadata) - async def update_execution_duration( - self, job_id: str, execution_duration: timedelta - ) -> None: - """Update the destruction time of a job. + async def _request( + self, + method: str, + token: str, + job_id: str | None = None, + *, + body: BaseModel | None = None, + query: list[tuple[str, str]] | None = None, + ) -> Response: + """Send an HTTP request to Wobbly. Parameters ---------- + method + HTTP method. + token + Token for an individual user. job_id - Identifier of the job. - execution_duration - New execution duration. + Identifier of job to act on. + body + If given, a Pydantic model that should be serialized to create the + JSON body of the request. + query + If given, query parameters to send. + + Returns + ------- + Response + HTTP response object. + + Raises + ------ + UnknownJobError + Raised if the request was for a specific job and that job was not + found. + WobblyError + Raised if the HTTP request fails or returns a failure status.
""" - async with self._session.begin(): - job = await self._get_job(job_id) - job.execution_duration = int(execution_duration.total_seconds()) - - async def _get_job(self, job_id: str) -> SQLJob: - """Retrieve a job from the database by job ID.""" - stmt = select(SQLJob).where(SQLJob.id == int(job_id)) - job = (await self._session.execute(stmt)).scalar_one_or_none() - if not job: - raise UnknownJobError(job_id) - return job + kwargs: dict[str, Any] = { + "headers": {"Authorization": f"bearer {token}"} + } + if body: + kwargs["json"] = body.model_dump(mode="json") + if query: + kwargs["params"] = query + url = self._base_url + "/jobs" + if job_id: + url += "/" + job_id + try: + r = await self._client.request(method, url, **kwargs) + if r.status_code == 404 and job_id: + raise UnknownJobError(job_id) + r.raise_for_status() + except HTTPError as e: + raise WobblyError.from_exception(e) from e + return r diff --git a/safir/src/safir/uws/_workers.py b/safir/src/safir/uws/_workers.py index 79c35ff2..48f94caf 100644 --- a/safir/src/safir/uws/_workers.py +++ b/safir/src/safir/uws/_workers.py @@ -6,10 +6,9 @@ import contextlib import uuid from datetime import UTC, datetime -from pathlib import Path from typing import Any, ParamSpec -from sqlalchemy.ext.asyncio import async_scoped_session +from httpx import AsyncClient from structlog.stdlib import BoundLogger from safir.arq import ( @@ -22,20 +21,15 @@ RedisArqQueue, ) from safir.arq.uws import WorkerError, WorkerTransientError -from safir.database import ( - create_async_session, - create_database_engine, - is_database_current, -) from safir.datetime import format_datetime_for_logging from safir.dependencies.http_client import http_client_dependency from safir.slack.blockkit import SlackException from safir.slack.webhook import SlackIgnoredException, SlackWebhookClient from ._config import UWSConfig -from ._constants import JOB_RESULT_TIMEOUT -from ._exceptions import DatabaseSchemaError, TaskError, UnknownJobError -from 
._models import UWSJob +from ._constants import JOB_RESULT_TIMEOUT, WOBBLY_REQUEST_TIMEOUT +from ._exceptions import TaskError, UnknownJobError +from ._models import Job from ._service import JobService from ._storage import JobStore @@ -44,18 +38,13 @@ __all__ = [ "close_uws_worker_context", "create_uws_worker_context", - "uws_expire_jobs", "uws_job_completed", "uws_job_started", ] async def create_uws_worker_context( - config: UWSConfig, - logger: BoundLogger, - *, - check_schema: bool = False, - alembic_config_path: Path = Path("alembic.ini"), + config: UWSConfig, logger: BoundLogger ) -> dict[str, Any]: """Construct the arq context for UWS workers. @@ -68,21 +57,11 @@ async def create_uws_worker_context( UWS configuration. logger Logger for the worker to use. - check_schema - Whether to check the database schema version with Alembic on startup. - alembic_config_path - When checking the schema, use this path to the Alembic - configuration. Returns ------- dict Keys to add to the ``ctx`` dictionary. - - Raises - ------ - DatabaseSchemaError - Raised if the UWS database schema is out of date. 
""" logger = logger.bind(worker_instance=uuid.uuid4().hex) @@ -95,16 +74,8 @@ async def create_uws_worker_context( else: arq = MockArqQueue() - engine = create_database_engine( - config.database_url, - config.database_password, - isolation_level="REPEATABLE READ", - ) - if check_schema: - if not await is_database_current(engine, logger, alembic_config_path): - raise DatabaseSchemaError("UWS database schema out of date") - session = await create_async_session(engine, logger) - storage = JobStore(session) + http_client = AsyncClient(timeout=WOBBLY_REQUEST_TIMEOUT) + storage = JobStore(config, http_client) service = JobService( config=config, arq_queue=arq, storage=storage, logger=logger ) @@ -119,9 +90,9 @@ async def create_uws_worker_context( logger.info("Worker startup complete") return { "arq": arq, + "http_client": http_client, "logger": logger, "service": service, - "session": session, "slack": slack, "storage": storage, } @@ -139,9 +110,9 @@ async def close_uws_worker_context(ctx: dict[Any, Any]) -> None: Worker context. """ logger: BoundLogger = ctx["logger"] - session: async_scoped_session = ctx["session"] + http_client: AsyncClient = ctx["http_client"] - await session.remove() + await http_client.aclose() # Possibly initialized by the Slack webhook client. await http_client_dependency.aclose() @@ -149,27 +120,8 @@ async def close_uws_worker_context(ctx: dict[Any, Any]) -> None: logger.info("Worker shutdown complete") -async def uws_expire_jobs(ctx: dict[Any, Any]) -> None: - """Delete jobs that have passed their destruction time. - - Parameters - ---------- - ctx - arq context. 
- """ - slack: SlackWebhookClient | None = ctx["slack"] - service: JobService = ctx["service"] - - try: - await service.delete_expired() - except Exception as e: - if slack: - await slack.post_uncaught_exception(e) - raise - - async def uws_job_started( - ctx: dict[Any, Any], job_id: str, start_time: datetime + ctx: dict[Any, Any], token: str, job_id: str, start_time: datetime ) -> None: """Mark a UWS job as executing. @@ -177,6 +129,8 @@ async def uws_job_started( ---------- ctx arq context. + token + Token for the user executing the job. job_id UWS job identifier. start_time @@ -187,7 +141,7 @@ async def uws_job_started( storage: JobStore = ctx["storage"] try: - await storage.mark_executing(job_id, start_time) + await storage.mark_executing(token, job_id, start_time) logger.info( "Marked job as started", start_time=format_datetime_for_logging(start_time), @@ -201,7 +155,7 @@ async def uws_job_started( async def _annotate_worker_error( - exc: Exception, job: UWSJob, slack: SlackWebhookClient | None = None + exc: Exception, job: Job, slack: SlackWebhookClient | None = None ) -> Exception: """Convert and possibly report a backend worker error. @@ -227,7 +181,7 @@ async def _annotate_worker_error( match exc: case WorkerError(): error = TaskError.from_worker_error(exc) - error.job_id = job.job_id + error.job_id = job.id error.started_at = job.creation_time error.user = job.owner if slack and not error.slack_ignore: @@ -259,7 +213,9 @@ async def _get_job_result(arq: ArqQueue, arq_job_id: str) -> JobResult: return await arq.get_job_result(arq_job_id) -async def uws_job_completed(ctx: dict[Any, Any], job_id: str) -> None: +async def uws_job_completed( + ctx: dict[Any, Any], token: str, job_id: str +) -> None: """Mark a UWS job as completed. Recover the exception if the job failed and record that as the job error. @@ -270,6 +226,8 @@ async def uws_job_completed(ctx: dict[Any, Any], job_id: str) -> None: ---------- ctx arq context. 
+ token + Token for the user executing the job. job_id UWS job identifier. """ @@ -281,7 +239,7 @@ async def uws_job_completed(ctx: dict[Any, Any], job_id: str) -> None: storage: JobStore = ctx["storage"] try: - job = await storage.get(job_id) + job = await storage.get(token, job_id) arq_job_id = job.message_id if not arq_job_id: msg = "Job has no associated arq job ID, cannot mark completed" @@ -301,7 +259,7 @@ async def uws_job_completed(ctx: dict[Any, Any], job_id: str) -> None: ) exc.__cause__ = e error = await _annotate_worker_error(exc, job, slack) - await storage.mark_failed(job_id, error) + await storage.mark_failed(token, job_id, error) return # If the job failed and Slack reporting is enabled, annotate the job @@ -311,7 +269,7 @@ async def uws_job_completed(ctx: dict[Any, Any], job_id: str) -> None: result.result = error # Mark the job as completed. - await storage.mark_completed(job_id, result) + await storage.mark_completed(token, job_id, result) logger.info("Marked job as completed") except UnknownJobError: logger.warning("Job not found to mark as completed") diff --git a/safir/src/safir/uws/templates/error.xml b/safir/src/safir/uws/templates/error.xml index 5fcddef6..fc7a5183 100644 --- a/safir/src/safir/uws/templates/error.xml +++ b/safir/src/safir/uws/templates/error.xml @@ -1,7 +1,7 @@ -{{ error.error_code.value }}: {{ error.message }} +{{ error.code }}: {{ error.message }} {%- if error.detail %} {{ error.detail }} diff --git a/safir/tests/data/database/uws/README.md b/safir/tests/data/database/uws/README.md deleted file mode 100644 index bb5d94d7..00000000 --- a/safir/tests/data/database/uws/README.md +++ /dev/null @@ -1,9 +0,0 @@ -The migration in this directory was generated with the following commands: - -```shell -docker-compose -f ../docker-compose.yaml up -env TEST_DATABASE_URL=postgresql://example@localhost/example \ - TEST_DATABASE_PASSWORD=INSECURE \ - PYTHONPATH=$(pwd)/../../../.. 
\ - alembic revision --autogenerate -m "UWS schema" -``` diff --git a/safir/tests/data/database/uws/alembic.ini b/safir/tests/data/database/uws/alembic.ini deleted file mode 100644 index 9b1acd83..00000000 --- a/safir/tests/data/database/uws/alembic.ini +++ /dev/null @@ -1,17 +0,0 @@ -# Keep in sync with the instructions in the Safir user guide. - -[alembic] -script_location = %(here)s/alembic -file_template = %%(year)d%%(month).2d%%(day).2d_%%(hour).2d%%(minute).2d_%%(rev)s_%%(slug)s -prepend_sys_path = . -timezone = UTC -version_path_separator = os - -[post_write_hooks] -hooks = ruff ruff_format -ruff.type = exec -ruff.executable = ruff -ruff.options = check --fix REVISION_SCRIPT_FILENAME -ruff_format.type = exec -ruff_format.executable = ruff -ruff_format.options = format REVISION_SCRIPT_FILENAME diff --git a/safir/tests/data/database/uws/alembic/env.py b/safir/tests/data/database/uws/alembic/env.py deleted file mode 100644 index 3bab44ff..00000000 --- a/safir/tests/data/database/uws/alembic/env.py +++ /dev/null @@ -1,25 +0,0 @@ -"""Alembic migration environment. - -Keep in sync with the instructions in the Safir user guide. -""" - -from alembic import context - -from safir.database import run_migrations_offline, run_migrations_online -from safir.logging import configure_alembic_logging, configure_logging -from safir.uws import UWSSchemaBase -from tests.support.alembic import config - -# Configure structlog. -configure_logging(name="tests", log_level=config.log_level) -configure_alembic_logging() - -# Run the migrations. 
-if context.is_offline_mode(): - run_migrations_offline(UWSSchemaBase.metadata, config.database_url) -else: - run_migrations_online( - UWSSchemaBase.metadata, - config.database_url, - config.database_password, - ) diff --git a/safir/tests/data/database/uws/alembic/script.py.mako b/safir/tests/data/database/uws/alembic/script.py.mako deleted file mode 100644 index fbc4b07d..00000000 --- a/safir/tests/data/database/uws/alembic/script.py.mako +++ /dev/null @@ -1,26 +0,0 @@ -"""${message} - -Revision ID: ${up_revision} -Revises: ${down_revision | comma,n} -Create Date: ${create_date} - -""" -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa -${imports if imports else ""} - -# revision identifiers, used by Alembic. -revision: str = ${repr(up_revision)} -down_revision: Union[str, None] = ${repr(down_revision)} -branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} -depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} - - -def upgrade() -> None: - ${upgrades if upgrades else "pass"} - - -def downgrade() -> None: - ${downgrades if downgrades else "pass"} diff --git a/safir/tests/data/database/uws/alembic/versions/20240911_0000_e9299566bc19_uws_schema.py b/safir/tests/data/database/uws/alembic/versions/20240911_0000_e9299566bc19_uws_schema.py deleted file mode 100644 index 9e876765..00000000 --- a/safir/tests/data/database/uws/alembic/versions/20240911_0000_e9299566bc19_uws_schema.py +++ /dev/null @@ -1,127 +0,0 @@ -"""UWS schema - -Revision ID: e9299566bc19 -Revises: -Create Date: 2024-09-11 00:00:57.336783+00:00 - -""" - -from collections.abc import Sequence - -import sqlalchemy as sa -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "e9299566bc19" -down_revision: str | None = None -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### - op.create_table( - "job", - sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), - sa.Column("message_id", sa.String(length=64), nullable=True), - sa.Column("owner", sa.String(length=64), nullable=False), - sa.Column( - "phase", - sa.Enum( - "PENDING", - "QUEUED", - "EXECUTING", - "COMPLETED", - "ERROR", - "UNKNOWN", - "HELD", - "SUSPENDED", - "ABORTED", - "ARCHIVED", - name="executionphase", - ), - nullable=False, - ), - sa.Column("run_id", sa.String(length=64), nullable=True), - sa.Column("creation_time", sa.DateTime(), nullable=False), - sa.Column("start_time", sa.DateTime(), nullable=True), - sa.Column("end_time", sa.DateTime(), nullable=True), - sa.Column("destruction_time", sa.DateTime(), nullable=False), - sa.Column("execution_duration", sa.Integer(), nullable=False), - sa.Column("quote", sa.DateTime(), nullable=True), - sa.Column( - "error_type", - sa.Enum("TRANSIENT", "FATAL", name="errortype"), - nullable=True, - ), - sa.Column( - "error_code", - sa.Enum( - "AUTHENTICATION_ERROR", - "AUTHORIZATION_ERROR", - "MULTIVALUED_PARAM_NOT_SUPPORTED", - "ERROR", - "SERVICE_UNAVAILABLE", - "USAGE_ERROR", - name="errorcode", - ), - nullable=True, - ), - sa.Column("error_message", sa.Text(), nullable=True), - sa.Column("error_detail", sa.Text(), nullable=True), - sa.PrimaryKeyConstraint("id"), - ) - op.create_index( - "by_owner_phase", - "job", - ["owner", "phase", "creation_time"], - unique=False, - ) - op.create_index( - "by_owner_time", "job", ["owner", "creation_time"], unique=False - ) - op.create_table( - "job_parameter", - sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), - sa.Column("job_id", sa.Integer(), nullable=False), - sa.Column("parameter", sa.String(length=64), nullable=False), - sa.Column("value", sa.Text(), nullable=False), - sa.Column("is_post", sa.Boolean(), nullable=False), - sa.ForeignKeyConstraint(["job_id"], ["job.id"], ondelete="CASCADE"), - sa.PrimaryKeyConstraint("id"), - ) - op.create_index( - 
"by_parameter", "job_parameter", ["job_id", "parameter"], unique=False - ) - op.create_table( - "job_result", - sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), - sa.Column("job_id", sa.Integer(), nullable=False), - sa.Column("result_id", sa.String(length=64), nullable=False), - sa.Column("sequence", sa.Integer(), nullable=False), - sa.Column("url", sa.String(length=256), nullable=False), - sa.Column("size", sa.Integer(), nullable=True), - sa.Column("mime_type", sa.String(length=64), nullable=True), - sa.ForeignKeyConstraint(["job_id"], ["job.id"], ondelete="CASCADE"), - sa.PrimaryKeyConstraint("id"), - ) - op.create_index( - "by_result_id", "job_result", ["job_id", "result_id"], unique=True - ) - op.create_index( - "by_sequence", "job_result", ["job_id", "sequence"], unique=True - ) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_index("by_sequence", table_name="job_result") - op.drop_index("by_result_id", table_name="job_result") - op.drop_table("job_result") - op.drop_index("by_parameter", table_name="job_parameter") - op.drop_table("job_parameter") - op.drop_index("by_owner_time", table_name="job") - op.drop_index("by_owner_phase", table_name="job") - op.drop_table("job") - # ### end Alembic commands ### diff --git a/safir/tests/support/uws.py b/safir/tests/support/uws.py index 6d358709..dbc16050 100644 --- a/safir/tests/support/uws.py +++ b/safir/tests/support/uws.py @@ -3,7 +3,7 @@ from __future__ import annotations from datetime import timedelta -from typing import Annotated, Self +from typing import Annotated from arq.connections import RedisSettings from fastapi import Form, Query @@ -11,7 +11,7 @@ from vo_models.uws import JobSummary, Parameter, Parameters from safir.arq import ArqMode -from safir.uws import ParametersModel, UWSConfig, UWSJobParameter, UWSRoute +from safir.uws import ParametersModel, UWSConfig, UWSRoute __all__ = [ 
"SimpleParameters", @@ -33,12 +33,6 @@ class SimpleParameters( ): name: str - @classmethod - def from_job_parameters(cls, params: list[UWSJobParameter]) -> Self: - assert len(params) == 1 - assert params[0].parameter_id == "name" - return cls(name=params[0].value) - def to_worker_parameters(self) -> SimpleWorkerParameters: return SimpleWorkerParameters(name=self.name) @@ -48,17 +42,17 @@ def to_xml_model(self) -> SimpleXmlParameters: async def _get_dependency( name: Annotated[str, Query()], -) -> list[UWSJobParameter]: - return [UWSJobParameter(parameter_id="name", value=name)] +) -> SimpleParameters: + return SimpleParameters(name=name) async def _post_dependency( name: Annotated[str, Form()], -) -> list[UWSJobParameter]: - return [UWSJobParameter(parameter_id="name", value=name)] +) -> SimpleParameters: + return SimpleParameters(name=name) -def build_uws_config(database_url: str, database_password: str) -> UWSConfig: +def build_uws_config() -> UWSConfig: """Set up a test configuration.""" return UWSConfig( arq_mode=ArqMode.test, @@ -66,8 +60,6 @@ def build_uws_config(database_url: str, database_password: str) -> UWSConfig: async_post_route=UWSRoute( dependency=_post_dependency, summary="Create async job" ), - database_url=database_url, - database_password=SecretStr(database_password), execution_duration=timedelta(minutes=10), job_summary_type=JobSummary[SimpleXmlParameters], lifetime=timedelta(days=1), @@ -80,5 +72,6 @@ def build_uws_config(database_url: str, database_password: str) -> UWSConfig: sync_post_route=UWSRoute( dependency=_post_dependency, summary="Sync request" ), + wobbly_url="https://example.com/wobbly", worker="hello", ) diff --git a/safir/tests/uws/alembic_test.py b/safir/tests/uws/alembic_test.py deleted file mode 100644 index 8e35d39d..00000000 --- a/safir/tests/uws/alembic_test.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Test UWS integration with Alembic.""" - -from __future__ import annotations - -from pathlib import Path -from typing import Any - 
-import pytest -from structlog.stdlib import BoundLogger - -from safir.uws import DatabaseSchemaError, UWSApplication, UWSConfig - - -@pytest.mark.asyncio -async def test_database_init( - uws_config: UWSConfig, logger: BoundLogger -) -> None: - uws = UWSApplication(uws_config) - config_path = ( - Path(__file__).parent.parent - / "data" - / "database" - / "uws" - / "alembic.ini" - ) - worker_settings = uws.build_worker( - logger, check_schema=True, alembic_config_path=config_path - ) - assert worker_settings.on_startup - assert worker_settings.on_shutdown - worker_ctx: dict[Any, Any] = {} - worker_startup = worker_settings.on_startup - worker_shutdown = worker_settings.on_shutdown - - # Initialize the database without Alembic. - await uws.initialize_uws_database(logger, reset=True) - - # Initializing a FastAPI app, or creating a UWS worker, should both fail - # because the database is not current. - assert not await uws.is_schema_current(logger, config_path) - with pytest.raises(DatabaseSchemaError): - await uws.initialize_fastapi( - logger, check_schema=True, alembic_config_path=config_path - ) - with pytest.raises(DatabaseSchemaError): - await worker_startup(worker_ctx) - - # Reinitialize the database with Alembic. - await uws.initialize_uws_database( - logger, - reset=True, - use_alembic=True, - alembic_config_path=config_path, - ) - - # Other initializations should now succeed. - assert await uws.is_schema_current(logger, config_path) - await uws.initialize_fastapi( - logger, check_schema=True, alembic_config_path=config_path - ) - await worker_startup(worker_ctx) - - # Clean up those initializations. 
- await uws.shutdown_fastapi() - await worker_shutdown(worker_ctx) diff --git a/safir/tests/uws/conftest.py b/safir/tests/uws/conftest.py index 2277fc1f..3906cedd 100644 --- a/safir/tests/uws/conftest.py +++ b/safir/tests/uws/conftest.py @@ -5,14 +5,13 @@ from collections.abc import AsyncIterator, Iterator from contextlib import asynccontextmanager from datetime import timedelta -from typing import Annotated import pytest import pytest_asyncio import respx import structlog from asgi_lifespan import LifespanManager -from fastapi import APIRouter, Body, FastAPI +from fastapi import APIRouter, FastAPI from httpx import ASGITransport, AsyncClient from structlog.stdlib import BoundLogger @@ -22,37 +21,19 @@ from safir.slack.webhook import SlackRouteErrorHandler from safir.testing.gcs import MockStorageClient, patch_google_storage from safir.testing.slack import MockSlackWebhook, mock_slack_webhook -from safir.testing.uws import MockUWSJobRunner -from safir.uws import UWSApplication, UWSConfig, UWSJobParameter +from safir.testing.uws import MockUWSJobRunner, MockWobbly, patch_wobbly +from safir.uws import UWSApplication, UWSConfig from safir.uws._dependencies import UWSFactory, uws_dependency from ..support.uws import build_uws_config -@pytest.fixture -def post_params_router() -> APIRouter: - """Return a router that echoes the parameters passed in the request.""" - router = APIRouter() - - @router.post("/params") - async def post_params( - params: Annotated[list[UWSJobParameter], Body()], - ) -> dict[str, list[dict[str, str]]]: - return { - "params": [ - {"id": p.parameter_id, "value": p.value} for p in params - ] - } - - return router - - @pytest_asyncio.fixture async def app( arq_queue: MockArqQueue, + mock_wobbly: MockWobbly, uws_config: UWSConfig, logger: BoundLogger, - post_params_router: APIRouter, ) -> AsyncIterator[FastAPI]: """Return a configured test application for UWS. @@ -61,7 +42,6 @@ async def app( the pieces added by an application. 
""" uws = UWSApplication(uws_config) - await uws.initialize_uws_database(logger, reset=True) uws.override_arq_queue(arq_queue) @asynccontextmanager @@ -76,7 +56,6 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]: router = APIRouter(route_class=SlackRouteErrorHandler) uws.install_handlers(router) app.include_router(router, prefix="/test") - app.include_router(post_params_router, prefix="/test") uws.install_error_handlers(app) async with LifespanManager(app): @@ -89,13 +68,17 @@ def arq_queue() -> MockArqQueue: @pytest_asyncio.fixture -async def client(app: FastAPI) -> AsyncIterator[AsyncClient]: +async def client( + app: FastAPI, test_token: str, test_username: str +) -> AsyncIterator[AsyncClient]: """Return an ``httpx.AsyncClient`` configured to talk to the test app.""" - transport = ASGITransport(app=app) async with AsyncClient( - transport=transport, + transport=ASGITransport(app=app), base_url="https://example.com/", - headers={"X-Auth-Request-Token": "sometoken"}, + headers={ + "X-Auth-Request-Token": test_token, + "X-Auth-Request-User": test_username, + }, ) as client: yield client @@ -123,22 +106,38 @@ def mock_slack( ) +@pytest.fixture +def mock_wobbly(respx_mock: respx.Router, uws_config: UWSConfig) -> MockWobbly: + return patch_wobbly(respx_mock, str(uws_config.wobbly_url)) + + @pytest_asyncio.fixture async def runner( uws_config: UWSConfig, arq_queue: MockArqQueue -) -> AsyncIterator[MockUWSJobRunner]: - async with MockUWSJobRunner(uws_config, arq_queue) as runner: - yield runner +) -> MockUWSJobRunner: + return MockUWSJobRunner(uws_config, arq_queue) + + +@pytest.fixture +def test_service() -> str: + return "test-service" + + +@pytest.fixture +def test_token(test_service: str, test_username: str) -> str: + return MockWobbly.make_token(test_service, test_username) + + +@pytest.fixture +def test_username() -> str: + return "test-user" @pytest.fixture -def uws_config(database_url: str, database_password: str) -> UWSConfig: - return 
build_uws_config(database_url, database_password) +def uws_config() -> UWSConfig: + return build_uws_config() @pytest_asyncio.fixture -async def uws_factory( - app: FastAPI, logger: BoundLogger -) -> AsyncIterator[UWSFactory]: - async for factory in uws_dependency(logger): - yield factory +async def uws_factory(app: FastAPI, logger: BoundLogger) -> UWSFactory: + return await uws_dependency(AsyncClient(), logger) diff --git a/safir/tests/uws/errors_test.py b/safir/tests/uws/errors_test.py index a4fe1e00..fd91cf5f 100644 --- a/safir/tests/uws/errors_test.py +++ b/safir/tests/uws/errors_test.py @@ -8,9 +8,11 @@ from httpx import AsyncClient from safir.testing.slack import MockSlackWebhook -from safir.uws import UWSJobParameter +from safir.testing.uws import MockWobbly from safir.uws._dependencies import UWSFactory +from ..support.uws import SimpleParameters + @dataclass class PostTest: @@ -22,16 +24,19 @@ class PostTest: @pytest.mark.asyncio async def test_errors( - client: AsyncClient, uws_factory: UWSFactory, mock_slack: MockSlackWebhook + client: AsyncClient, + test_token: str, + test_service: str, + uws_factory: UWSFactory, + mock_slack: MockSlackWebhook, + mock_wobbly: MockWobbly, ) -> None: job_service = uws_factory.create_job_service() await job_service.create( - "user", - run_id="some-run-id", - params=[UWSJobParameter(parameter_id="name", value="June")], + test_token, SimpleParameters(name="June"), run_id="some-run-id" ) - # No user specified. + # No token provided. routes = [ "/test/jobs/1", "/test/jobs/1/destruction", @@ -44,21 +49,24 @@ async def test_errors( "/test/jobs/1/results", ] for route in routes: - r = await client.get(route) + request = client.build_request("GET", route) + del request.headers["X-Auth-Request-Token"] + r = await client.send(request) assert r.status_code == 422 assert r.text.startswith("UsageError") # Wrong user specified. 
+ other_token = MockWobbly.make_token(test_service, "other-user") for route in routes: r = await client.get( - route, headers={"X-Auth-Request-User": "otheruser"} + route, headers={"X-Auth-Request-Token": other_token} ) - assert r.status_code == 403 - assert r.text.startswith("AuthorizationError") + assert r.status_code == 404 + assert r.text.startswith("UsageError") # Job does not exist. for route in (r.replace("/1", "/2") for r in routes): - r = await client.get(route, headers={"X-Auth-Request-User": "user"}) + r = await client.get(route) assert r.status_code == 404 assert r.text.startswith("UsageError") @@ -74,7 +82,9 @@ async def test_errors( PostTest("/test/jobs/1/phase", {"phase": "RUN"}), ] for test in tests: - r = await client.post(test.url, data=test.data) + request = client.build_request("POST", test.url, data=test.data) + del request.headers["X-Auth-Request-Token"] + r = await client.send(request) assert r.status_code == 422 assert r.text.startswith("UsageError") @@ -83,34 +93,32 @@ async def test_errors( r = await client.post( test.url, data=test.data, - headers={"X-Auth-Request-User": "otheruser"}, + headers={"X-Auth-Request-Token": other_token}, ) - assert r.status_code == 403 - assert r.text.startswith("AuthorizationError") + assert r.status_code == 404 + assert r.text.startswith("UsageError") # Job does not exist. for test in tests: url = test.url.replace("/1", "/2") - r = await client.post( - url, data=test.data, headers={"X-Auth-Request-User": "user"} - ) + r = await client.post(url, data=test.data) assert r.status_code == 404 assert r.text.startswith("UsageError") # Finally, test all the same things with the one supported DELETE. 
- r = await client.delete("/test/jobs/1") + request = client.build_request("DELETE", "/test/jobs/1") + del request.headers["X-Auth-Request-Token"] + r = await client.send(request) assert r.status_code == 422 assert r.text.startswith("UsageError") r = await client.delete( - "/test/jobs/1", headers={"X-Auth-Request-User": "otheruser"} - ) - assert r.status_code == 403 - assert r.text.startswith("AuthorizationError") - r = await client.delete( - "/test/jobs/2", headers={"X-Auth-Request-User": "user"} + "/test/jobs/1", headers={"X-Auth-Request-Token": other_token} ) assert r.status_code == 404 assert r.text.startswith("UsageError") + r = await client.delete("/test/jobs/2") + assert r.status_code == 404 + assert r.text.startswith("UsageError") # Try some bogus destruction and execution duration parameters. tests = [ @@ -123,8 +131,6 @@ async def test_errors( "/test/jobs/1/destruction", {"destrucTION": "2021-09-10T10:01:02+00:00:00"}, ), - PostTest("/test/jobs/1/executionduration", {"executionduration": "0"}), - PostTest("/test/jobs/1/executionduration", {"executionDUration": "0"}), PostTest( "/test/jobs/1/executionduration", {"executionduration": "-1"} ), @@ -139,22 +145,18 @@ async def test_errors( ), ] for test in tests: - r = await client.post( - test.url, data=test.data, headers={"X-Auth-Request-User": "user"} - ) + r = await client.post(test.url, data=test.data) assert r.status_code == 422, f"{test.url} {test.data}" assert r.text.startswith("UsageError"), r.text # Test bogus phase for async job creation. 
r = await client.post( "/test/jobs?phase=START", - headers={"X-Auth-Request-User": "user"}, data={"runid": "some-run-id", "name": "Jane"}, ) assert r.status_code == 422 r = await client.post( "/test/jobs", - headers={"X-Auth-Request-User": "user"}, data={"runid": "some-run-id", "name": "Jane", "phase": "START"}, ) assert r.status_code == 422 diff --git a/safir/tests/uws/job_api_test.py b/safir/tests/uws/job_api_test.py index df62c0f4..3d3621c6 100644 --- a/safir/tests/uws/job_api_test.py +++ b/safir/tests/uws/job_api_test.py @@ -18,13 +18,13 @@ from vo_models.uws import JobSummary, Results from safir.arq import MockArqQueue -from safir.arq.uws import WorkerJobInfo +from safir.arq.uws import WorkerJobInfo, WorkerResult from safir.datetime import current_datetime, isodatetime from safir.testing.uws import MockUWSJobRunner, assert_job_summary_equal -from safir.uws import UWSConfig, UWSJob, UWSJobParameter, UWSJobResult +from safir.uws import Job, UWSConfig from safir.uws._dependencies import UWSFactory -from ..support.uws import SimpleXmlParameters +from ..support.uws import SimpleParameters, SimpleXmlParameters PENDING_JOB = """ {} some-run-id - user + test-user {} {} {} @@ -57,7 +57,7 @@ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"> 1 some-run-id - user + test-user COMPLETED {} {} @@ -84,7 +84,7 @@ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"> {} some-run-id - user + test-user ABORTED {} 600 @@ -105,7 +105,7 @@ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"> {} some-run-id - user + test-user ABORTED {} {} @@ -146,6 +146,8 @@ @pytest.mark.asyncio async def test_job_run( client: AsyncClient, + test_token: str, + test_username: str, runner: MockUWSJobRunner, uws_factory: UWSFactory, uws_config: UWSConfig, @@ -155,18 +157,15 @@ async def test_job_run( # Create the job. 
r = await client.post( "/test/jobs", - headers={"X-Auth-Request-User": "user"}, data={"runid": "some-run-id", "name": "Jane"}, ) assert r.status_code == 303 assert r.headers["Location"] == "https://example.com/test/jobs/1" - job = await job_service.get("user", "1") + job = await job_service.get(test_token, "1") assert job.creation_time.microsecond == 0 # Check the retrieval of the job configuration. - r = await client.get( - "/test/jobs/1", headers={"X-Auth-Request-User": "user"} - ) + r = await client.get("/test/jobs/1") assert r.status_code == 200 assert r.headers["Content-Type"] == "application/xml" assert_job_summary_equal( @@ -182,20 +181,13 @@ async def test_job_run( ) # Try to put the job in an invalid phase. - r = await client.post( - "/test/jobs/1/phase", - headers={"X-Auth-Request-User": "user"}, - data={"PHASE": "EXECUTING"}, - ) + r = await client.post("/test/jobs/1/phase", data={"PHASE": "EXECUTING"}) assert r.status_code == 422 assert r.text.startswith("UsageError") # Start the job. r = await client.post( - "/test/jobs/1/phase", - headers={"X-Auth-Request-User": "user"}, - data={"PHASE": "RUN"}, - follow_redirects=True, + "/test/jobs/1/phase", data={"PHASE": "RUN"}, follow_redirects=True ) assert r.status_code == 200 assert r.url == "https://example.com/test/jobs/1" @@ -210,16 +202,16 @@ async def test_job_run( isodatetime(job.creation_time + timedelta(seconds=24 * 60 * 60)), ), ) - await runner.mark_in_progress("user", "1") + await runner.mark_in_progress(test_token, "1") # Check that the correct data was passed to the backend worker. 
- metadata = await runner.get_job_metadata("user", "1") + metadata = await runner.get_job_metadata(test_token, "1") assert metadata.name == uws_config.worker assert metadata.args[0] == {"name": "Jane"} assert metadata.args[1] == WorkerJobInfo( job_id="1", - user="user", - token="sometoken", + user=test_username, + token=test_token, timeout=ANY, run_id="some-run-id", ) @@ -231,13 +223,13 @@ async def test_job_run( # Tell the queue the job is finished. results = [ - UWSJobResult( + WorkerResult( result_id="cutout", url="s3://some-bucket/some/path", mime_type="application/fits", ) ] - job = await runner.mark_complete("user", "1", results) + job = await runner.mark_complete(test_token, "1", results) # Check the job results. assert job.start_time @@ -245,9 +237,7 @@ async def test_job_run( assert job.end_time assert job.end_time.microsecond == 0 assert job.end_time >= job.start_time >= job.creation_time - r = await client.get( - "/test/jobs/1", headers={"X-Auth-Request-User": "user"} - ) + r = await client.get("/test/jobs/1") assert r.status_code == 200 assert r.headers["Content-Type"] == "application/xml" assert_job_summary_equal( @@ -263,31 +253,26 @@ async def test_job_run( ) # Check that the phase is now correct. - r = await client.get( - "/test/jobs/1/phase", headers={"X-Auth-Request-User": "user"} - ) + r = await client.get("/test/jobs/1/phase") assert r.status_code == 200 assert r.headers["Content-Type"] == "text/plain; charset=utf-8" assert r.text == "COMPLETED" # Retrieve them directly through the results resource. - r = await client.get( - "/test/jobs/1/results", headers={"X-Auth-Request-User": "user"} - ) + r = await client.get("/test/jobs/1/results") assert r.status_code == 200 assert r.headers["Content-Type"] == "application/xml" assert Results.from_xml(r.text) == Results.from_xml(JOB_RESULTS) # There should be no error message. 
- r = await client.get( - "/test/jobs/1/error", headers={"X-Auth-Request-User": "user"} - ) + r = await client.get("/test/jobs/1/error") assert r.status_code == 404 @pytest.mark.asyncio async def test_job_abort( client: AsyncClient, + test_token: str, runner: MockUWSJobRunner, arq_queue: MockArqQueue, uws_factory: UWSFactory, @@ -297,19 +282,14 @@ async def test_job_abort( # Create the job. r = await client.post( - "/test/jobs", - headers={"X-Auth-Request-User": "user"}, - data={"runid": "some-run-id", "name": "Jane"}, + "/test/jobs", data={"runid": "some-run-id", "name": "Jane"} ) assert r.status_code == 303 - job = await job_service.get("user", "1") + job = await job_service.get(test_token, "1") # Immediately abort the job. r = await client.post( - "/test/jobs/1/phase", - headers={"X-Auth-Request-User": "user"}, - data={"PHASE": "ABORT"}, - follow_redirects=True, + "/test/jobs/1/phase", data={"PHASE": "ABORT"}, follow_redirects=True ) assert r.status_code == 200 assert r.url == "https://example.com/test/jobs/1" @@ -326,23 +306,19 @@ async def test_job_abort( # Create a second job and start it running. r = await client.post( "/test/jobs", - headers={"X-Auth-Request-User": "user"}, data={"runid": "some-run-id", "name": "Jane", "phase": "RUN"}, ) assert r.status_code == 303 assert r.headers["Location"] == "https://example.com/test/jobs/2" - await runner.mark_in_progress("user", "2") + await runner.mark_in_progress(test_token, "2") # Abort that job. 
r = await client.post( - "/test/jobs/2/phase", - headers={"X-Auth-Request-User": "user"}, - data={"PHASE": "ABORT"}, - follow_redirects=True, + "/test/jobs/2/phase", data={"PHASE": "ABORT"}, follow_redirects=True ) assert r.status_code == 200 assert r.url == "https://example.com/test/jobs/2" - job = await job_service.get("user", "2") + job = await job_service.get(test_token, "2") assert job.start_time assert job.end_time assert_job_summary_equal( @@ -356,7 +332,7 @@ async def test_job_abort( isodatetime(job.creation_time + timedelta(seconds=24 * 60 * 60)), ), ) - job_result = await runner.get_job_result("user", "2") + job_result = await runner.get_job_result(test_token, "2") assert not job_result.success assert isinstance(job_result.result, asyncio.CancelledError) @@ -364,18 +340,13 @@ async def test_job_abort( # the phase parameter and the POST form of the delete support. r = await client.post( "/test/jobs", - headers={"X-Auth-Request-User": "user"}, data={"runid": "some-run-id", "name": "Jane", "PHAse": "RUN"}, ) assert r.status_code == 303 assert r.headers["Location"] == "https://example.com/test/jobs/3" - await runner.mark_in_progress("user", "3") - job = await job_service.get("user", "3") - r = await client.post( - "/test/jobs/3", - headers={"X-Auth-Request-User": "user"}, - data={"action": "DELETE"}, - ) + await runner.mark_in_progress(test_token, "3") + job = await job_service.get(test_token, "3") + r = await client.post("/test/jobs/3", data={"action": "DELETE"}) assert r.status_code == 303 assert r.headers["Location"] == "https://example.com/test/jobs" assert job.message_id @@ -387,25 +358,23 @@ async def test_job_abort( @pytest.mark.asyncio async def test_job_api( client: AsyncClient, + test_token: str, + test_username: str, uws_factory: UWSFactory, ) -> None: job_service = uws_factory.create_job_service() # Create the job. 
r = await client.post( - "/test/jobs", - headers={"X-Auth-Request-User": "user"}, - data={"runid": "some-run-id", "name": "Jane"}, + "/test/jobs", data={"runid": "some-run-id", "name": "Jane"} ) assert r.status_code == 303 assert r.headers["Location"] == "https://example.com/test/jobs/1" - job = await job_service.get("user", "1") + job = await job_service.get(test_token, "1") # Check the retrieval of the job configuration. destruction_time = job.creation_time + timedelta(seconds=24 * 60 * 60) - r = await client.get( - "/test/jobs/1", headers={"X-Auth-Request-User": "user"} - ) + r = await client.get("/test/jobs/1") assert r.status_code == 200 assert r.headers["Content-Type"] == "application/xml" assert_job_summary_equal( @@ -421,46 +390,33 @@ async def test_job_api( ) # Check retrieving each part separately. - r = await client.get( - "/test/jobs/1/destruction", headers={"X-Auth-Request-User": "user"} - ) + r = await client.get("/test/jobs/1/destruction") assert r.status_code == 200 assert r.headers["Content-Type"] == "text/plain; charset=utf-8" assert r.text == isodatetime(destruction_time) - r = await client.get( - "/test/jobs/1/executionduration", - headers={"X-Auth-Request-User": "user"}, - ) + r = await client.get("/test/jobs/1/executionduration") assert r.status_code == 200 assert r.headers["Content-Type"] == "text/plain; charset=utf-8" assert r.text == "600" - r = await client.get( - "/test/jobs/1/owner", headers={"X-Auth-Request-User": "user"} - ) + r = await client.get("/test/jobs/1/owner") assert r.status_code == 200 assert r.headers["Content-Type"] == "text/plain; charset=utf-8" - assert r.text == "user" + assert r.text == test_username - r = await client.get( - "/test/jobs/1/parameters", headers={"X-Auth-Request-User": "user"} - ) + r = await client.get("/test/jobs/1/parameters") assert r.status_code == 200 assert r.headers["Content-Type"] == "application/xml" expected = SimpleXmlParameters.from_xml(JOB_PARAMETERS) assert 
SimpleXmlParameters.from_xml(r.text) == expected - r = await client.get( - "/test/jobs/1/phase", headers={"X-Auth-Request-User": "user"} - ) + r = await client.get("/test/jobs/1/phase") assert r.status_code == 200 assert r.headers["Content-Type"] == "text/plain; charset=utf-8" assert r.text == "PENDING" - r = await client.get( - "/test/jobs/1/quote", headers={"X-Auth-Request-User": "user"} - ) + r = await client.get("/test/jobs/1/quote") assert r.status_code == 200 assert r.headers["Content-Type"] == "text/plain; charset=utf-8" assert r.text == "" @@ -468,25 +424,19 @@ async def test_job_api( # Modify various settings. Validators will be tested elsewhere. now = current_datetime() r = await client.post( - "/test/jobs/1/destruction", - headers={"X-Auth-Request-User": "user"}, - data={"DESTRUCTION": isodatetime(now)}, + "/test/jobs/1/destruction", data={"DESTRUCTION": isodatetime(now)} ) assert r.status_code == 303 assert r.headers["Location"] == "https://example.com/test/jobs/1" r = await client.post( - "/test/jobs/1/executionduration", - headers={"X-Auth-Request-User": "user"}, - data={"ExecutionDuration": 300}, + "/test/jobs/1/executionduration", data={"ExecutionDuration": 300} ) assert r.status_code == 303 assert r.headers["Location"] == "https://example.com/test/jobs/1" # Retrieve the modified job and check that the new values are recorded. - r = await client.get( - "/test/jobs/1", headers={"X-Auth-Request-User": "user"} - ) + r = await client.get("/test/jobs/1") assert r.status_code == 200 assert r.headers["Content-Type"] == "application/xml" assert_job_summary_equal( @@ -502,27 +452,19 @@ async def test_job_api( ) # Delete the job. 
- r = await client.delete( - "/test/jobs/1", headers={"X-Auth-Request-User": "user"} - ) + r = await client.delete("/test/jobs/1") assert r.status_code == 303 assert r.headers["Location"] == "https://example.com/test/jobs" - r = await client.get( - "/test/jobs/1", headers={"X-Auth-Request-User": "user"} - ) + r = await client.get("/test/jobs/1") assert r.status_code == 404 # Create a new job and then delete it via POST. r = await client.post( - "/test/jobs", - headers={"X-Auth-Request-User": "user"}, - data={"name": "Jane", "RUNID": "some-run-id"}, + "/test/jobs", data={"name": "Jane", "RUNID": "some-run-id"} ) assert r.status_code == 303 - job = await job_service.get("user", "2") - r = await client.get( - "/test/jobs/2", headers={"X-Auth-Request-User": "user"} - ) + job = await job_service.get(test_token, "2") + r = await client.get("/test/jobs/2") assert r.status_code == 200 assert_job_summary_equal( JobSummary[SimpleXmlParameters], @@ -535,23 +477,16 @@ async def test_job_api( isodatetime(job.destruction_time), ), ) - r = await client.post( - "/test/jobs/2", - headers={"X-Auth-Request-User": "user"}, - data={"ACTION": "DELETE"}, - ) + r = await client.post("/test/jobs/2", data={"ACTION": "DELETE"}) assert r.status_code == 303 assert r.headers["Location"] == "https://example.com/test/jobs" - r = await client.get( - "/test/jobs/2", headers={"X-Auth-Request-User": "user"} - ) + r = await client.get("/test/jobs/2") assert r.status_code == 404 @pytest.mark.asyncio async def test_redirects( - app: FastAPI, - uws_factory: UWSFactory, + app: FastAPI, test_token: str, uws_factory: UWSFactory ) -> None: """Test the scheme in the redirect URLs. 
@@ -562,20 +497,19 @@ async def test_redirects( """ job_service = uws_factory.create_job_service() await job_service.create( - "user", - run_id="some-run-id", - params=[UWSJobParameter(parameter_id="name", value="Peter")], + test_token, SimpleParameters(name="Peter"), run_id="some-run-id" ) # Try various actions that result in redirects and ensure the redirect is # correct. async with AsyncClient( - transport=ASGITransport(app=app), base_url="http://foo.com/" + transport=ASGITransport(app=app), + base_url="http://foo.com/", + headers={"X-Auth-Request-Token": test_token}, ) as client: r = await client.post( "/test/jobs/1/destruction", headers={ - "X-Auth-Request-User": "user", "Host": "example.com", "X-Forwarded-For": "10.10.10.10", "X-Forwarded-Proto": "https", @@ -589,7 +523,6 @@ async def test_redirects( r = await client.post( "/test/jobs/1/executionduration", headers={ - "X-Auth-Request-User": "user", "Host": "example.com", "X-Forwarded-For": "10.10.10.10", "X-Forwarded-Proto": "https", @@ -603,7 +536,6 @@ async def test_redirects( r = await client.delete( "/test/jobs/1", headers={ - "X-Auth-Request-User": "user", "Host": "example.com", "X-Forwarded-For": "10.10.10.10", "X-Forwarded-Proto": "https", @@ -617,32 +549,29 @@ async def test_redirects( @pytest.mark.asyncio async def test_presigned_url( client: AsyncClient, + test_token: str, runner: MockUWSJobRunner, uws_factory: UWSFactory, uws_config: UWSConfig, ) -> None: r = await client.post( - "/test/jobs?phase=RUN", - headers={"X-Auth-Request-User": "user"}, - data={"runid": "some-run-id", "name": "Jane"}, + "/test/jobs?phase=RUN", data={"runid": "some-run-id", "name": "Jane"} ) assert r.status_code == 303 - await runner.mark_in_progress("user", "1") + await runner.mark_in_progress(test_token, "1") # Tell the queue the job is finished, with an https URL. 
results = [ - UWSJobResult( + WorkerResult( result_id="cutout", url="https://example.com/some/path", mime_type="application/fits", ) ] - job = await runner.mark_complete("user", "1", results) + job = await runner.mark_complete(test_token, "1", results) # Check the job results, which should pass that URL through unchanged. - r = await client.get( - "/test/jobs/1", headers={"X-Auth-Request-User": "user"} - ) + r = await client.get("/test/jobs/1") assert r.status_code == 200 assert job.start_time assert job.end_time @@ -659,19 +588,25 @@ async def test_presigned_url( ) -def validate_destruction(destruction: datetime, job: UWSJob) -> datetime: +def validate_destruction(destruction: datetime, job: Job) -> datetime: max_destruction = current_datetime() + timedelta(days=1) return min(destruction, max_destruction) -def validate_execution_duration(duration: timedelta, job: UWSJob) -> timedelta: +def validate_execution_duration( + duration: timedelta | None, job: Job +) -> timedelta | None: max_duration = timedelta(seconds=200) - return min(duration, max_duration) + if not duration: + return max_duration + else: + return min(duration, max_duration) @pytest.mark.asyncio async def test_validators( client: AsyncClient, + test_token: str, arq_queue: MockArqQueue, uws_factory: UWSFactory, uws_config: UWSConfig, @@ -679,37 +614,29 @@ async def test_validators( uws_config.validate_destruction = validate_destruction uws_config.validate_execution_duration = validate_execution_duration job_service = uws_factory.create_job_service() - await job_service.create( - "user", params=[UWSJobParameter(parameter_id="name", value="Tiffany")] - ) + await job_service.create(test_token, SimpleParameters(name="Tiffany")) # Change the destruction time, first to something that should be honored # and then something that should be overridden. 
destruction = current_datetime() + timedelta(hours=1) r = await client.post( "/test/jobs/1/destruction", - headers={"X-Auth-Request-User": "user"}, data={"desTRUcTiON": isodatetime(destruction)}, ) assert r.status_code == 303 assert r.headers["Location"] == "https://example.com/test/jobs/1" - r = await client.get( - "/test/jobs/1/destruction", headers={"X-Auth-Request-User": "user"} - ) + r = await client.get("/test/jobs/1/destruction") assert r.status_code == 200 assert r.text == isodatetime(destruction) destruction = current_datetime() + timedelta(days=5) expected = current_datetime() + timedelta(days=1) r = await client.post( "/test/jobs/1/destruction", - headers={"X-Auth-Request-User": "user"}, data={"destruction": isodatetime(destruction)}, ) assert r.status_code == 303 assert r.headers["Location"] == "https://example.com/test/jobs/1" - r = await client.get( - "/test/jobs/1/destruction", headers={"X-Auth-Request-User": "user"} - ) + r = await client.get("/test/jobs/1/destruction") assert r.status_code == 200 seen = datetime.fromisoformat(r.text[:-1] + "+00:00") assert seen >= expected - timedelta(seconds=5) @@ -717,28 +644,18 @@ async def test_validators( # Now do the same thing for execution duration. 
r = await client.post( - "/test/jobs/1/executionduration", - headers={"X-Auth-Request-User": "user"}, - data={"exECUTionduRATION": 100}, + "/test/jobs/1/executionduration", data={"exECUTionduRATION": 100} ) assert r.status_code == 303 assert r.headers["Location"] == "https://example.com/test/jobs/1" - r = await client.get( - "/test/jobs/1/executionduration", - headers={"X-Auth-Request-User": "user"}, - ) + r = await client.get("/test/jobs/1/executionduration") assert r.status_code == 200 assert r.text == "100" r = await client.post( - "/test/jobs/1/executionduration", - headers={"X-Auth-Request-User": "user"}, - data={"exECUTionduRATION": 250}, + "/test/jobs/1/executionduration", data={"exECUTionduRATION": 250} ) assert r.status_code == 303 assert r.headers["Location"] == "https://example.com/test/jobs/1" - r = await client.get( - "/test/jobs/1/executionduration", - headers={"X-Auth-Request-User": "user"}, - ) + r = await client.get("/test/jobs/1/executionduration") assert r.status_code == 200 assert r.text == "200" diff --git a/safir/tests/uws/job_error_test.py b/safir/tests/uws/job_error_test.py index 8fabb8f3..5254754d 100644 --- a/safir/tests/uws/job_error_test.py +++ b/safir/tests/uws/job_error_test.py @@ -10,11 +10,10 @@ from safir.datetime import isodatetime from safir.testing.slack import MockSlackWebhook from safir.testing.uws import MockUWSJobRunner, assert_job_summary_equal -from safir.uws import UWSJobParameter from safir.uws._dependencies import UWSFactory from safir.uws._exceptions import TaskError -from ..support.uws import SimpleXmlParameters +from ..support.uws import SimpleParameters, SimpleXmlParameters ERRORED_JOB = """ 1 - user + test-user ERROR {} {} @@ -54,39 +53,30 @@ @pytest.mark.asyncio async def test_temporary_error( client: AsyncClient, + test_token: str, runner: MockUWSJobRunner, uws_factory: UWSFactory, mock_slack: MockSlackWebhook, ) -> None: job_service = uws_factory.create_job_service() - await job_service.create( - "user", 
params=[UWSJobParameter(parameter_id="name", value="Sarah")] - ) + await job_service.create(test_token, SimpleParameters(name="Sarah")) # The pending job has no error. - r = await client.get( - "/test/jobs/1/error", headers={"X-Auth-Request-User": "user"} - ) + r = await client.get("/test/jobs/1/error") assert r.status_code == 404 # Execute the job. - r = await client.post( - "/test/jobs/1/phase", - headers={"X-Auth-Request-User": "user"}, - data={"PHASE": "RUN"}, - ) + r = await client.post("/test/jobs/1/phase", data={"PHASE": "RUN"}) assert r.status_code == 303 - await runner.mark_in_progress("user", "1") + await runner.mark_in_progress(test_token, "1") exc = WorkerTransientError("Something failed") result = TaskError.from_worker_error(exc) - job = await runner.mark_complete("user", "1", result) + job = await runner.mark_complete(test_token, "1", result) # Check the results. assert job.start_time assert job.end_time - r = await client.get( - "/test/jobs/1", headers={"X-Auth-Request-User": "user"} - ) + r = await client.get("/test/jobs/1") assert r.status_code == 200 assert_job_summary_equal( JobSummary[SimpleXmlParameters], @@ -103,9 +93,7 @@ async def test_temporary_error( ) # Retrieve the error separately. - r = await client.get( - "/test/jobs/1/error", headers={"X-Auth-Request-User": "user"} - ) + r = await client.get("/test/jobs/1/error") assert r.status_code == 200 assert r.text == JOB_ERROR_SUMMARY.strip().format( "ServiceUnavailable: Something failed" @@ -118,33 +106,26 @@ async def test_temporary_error( @pytest.mark.asyncio async def test_fatal_error( client: AsyncClient, + test_token: str, runner: MockUWSJobRunner, uws_factory: UWSFactory, mock_slack: MockSlackWebhook, ) -> None: job_service = uws_factory.create_job_service() - await job_service.create( - "user", params=[UWSJobParameter(parameter_id="name", value="Sarah")] - ) + await job_service.create(test_token, SimpleParameters(name="Sarah")) # Start the job. 
- r = await client.post( - "/test/jobs/1/phase", - headers={"X-Auth-Request-User": "user"}, - data={"PHASE": "RUN"}, - ) + r = await client.post("/test/jobs/1/phase", data={"PHASE": "RUN"}) assert r.status_code == 303 - await runner.mark_in_progress("user", "1") + await runner.mark_in_progress(test_token, "1") exc = WorkerFatalError("Whoops", "Some details") result = TaskError.from_worker_error(exc) - job = await runner.mark_complete("user", "1", result) + job = await runner.mark_complete(test_token, "1", result) # Check the results. assert job.start_time assert job.end_time - r = await client.get( - "/test/jobs/1", headers={"X-Auth-Request-User": "user"} - ) + r = await client.get("/test/jobs/1") assert r.status_code == 200 assert_job_summary_equal( JobSummary[SimpleXmlParameters], @@ -161,9 +142,7 @@ async def test_fatal_error( ) # Retrieve the error separately. - r = await client.get( - "/test/jobs/1/error", headers={"X-Auth-Request-User": "user"} - ) + r = await client.get("/test/jobs/1/error") assert r.status_code == 200 assert r.text == JOB_ERROR_SUMMARY.strip().format( "Error: Whoops\n\nSome details" @@ -176,32 +155,25 @@ async def test_fatal_error( @pytest.mark.asyncio async def test_unknown_error( client: AsyncClient, + test_token: str, runner: MockUWSJobRunner, uws_factory: UWSFactory, mock_slack: MockSlackWebhook, ) -> None: job_service = uws_factory.create_job_service() - await job_service.create( - "user", params=[UWSJobParameter(parameter_id="name", value="Sarah")] - ) + await job_service.create(test_token, SimpleParameters(name="Sarah")) # Start the job. 
- r = await client.post( - "/test/jobs/1/phase", - headers={"X-Auth-Request-User": "user"}, - data={"PHASE": "RUN"}, - ) + r = await client.post("/test/jobs/1/phase", data={"PHASE": "RUN"}) assert r.status_code == 303 - await runner.mark_in_progress("user", "1") + await runner.mark_in_progress(test_token, "1") result = ValueError("Unknown exception") - job = await runner.mark_complete("user", "1", result) + job = await runner.mark_complete(test_token, "1", result) # Check the results. assert job.start_time assert job.end_time - r = await client.get( - "/test/jobs/1", headers={"X-Auth-Request-User": "user"} - ) + r = await client.get("/test/jobs/1") assert r.status_code == 200 assert_job_summary_equal( JobSummary[SimpleXmlParameters], @@ -218,9 +190,7 @@ async def test_unknown_error( ) # Retrieve the error separately. - r = await client.get( - "/test/jobs/1/error", headers={"X-Auth-Request-User": "user"} - ) + r = await client.get("/test/jobs/1/error") assert r.status_code == 200 assert r.text == JOB_ERROR_SUMMARY.strip().format( "Error: Unknown error executing task\n\n" diff --git a/safir/tests/uws/job_list_test.py b/safir/tests/uws/job_list_test.py index 811580a8..2c2107e6 100644 --- a/safir/tests/uws/job_list_test.py +++ b/safir/tests/uws/job_list_test.py @@ -10,14 +10,13 @@ import pytest from httpx import AsyncClient -from sqlalchemy import update from vo_models.uws import Jobs -from safir.database import datetime_to_db from safir.datetime import current_datetime, isodatetime -from safir.uws import UWSJobParameter +from safir.testing.uws import MockWobbly from safir.uws._dependencies import UWSFactory -from safir.uws._schema import Job as SQLJob + +from ..support.uws import SimpleParameters FULL_JOB_LIST = """ PENDING - user + test-user {} PENDING some-run-id - user + test-user {} PENDING - user + test-user {} @@ -56,7 +55,7 @@ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"> PENDING - user + test-user {} @@ -73,7 +72,7 @@ QUEUED some-run-id - user + 
test-user {} @@ -81,48 +80,40 @@ @pytest.mark.asyncio -async def test_job_list(client: AsyncClient, uws_factory: UWSFactory) -> None: +async def test_job_list( + client: AsyncClient, + test_token: str, + test_service: str, + test_username: str, + uws_factory: UWSFactory, + mock_wobbly: MockWobbly, +) -> None: job_service = uws_factory.create_job_service() - jobs = [ - await job_service.create( - "user", params=[UWSJobParameter(parameter_id="name", value="Joe")] - ), - await job_service.create( - "user", - run_id="some-run-id", - params=[UWSJobParameter(parameter_id="name", value="Catherine")], - ), - await job_service.create( - "user", params=[UWSJobParameter(parameter_id="name", value="Pat")] - ), - ] + await job_service.create(test_token, SimpleParameters(name="Joe")) + await job_service.create( + test_token, + SimpleParameters(name="Catherine"), + run_id="some-run-id", + ) + await job_service.create(test_token, SimpleParameters(name="Pat")) # Create an additional job for a different user, which shouldn't appear in # any of the lists. - await job_service.create( - "otheruser", - params=[UWSJobParameter(parameter_id="name", value="Dominique")], - ) + other_token = MockWobbly.make_token(test_service, "other-user") + await job_service.create(other_token, SimpleParameters(name="Dominique")) # Adjust the creation time of the jobs so that searches are more # interesting. - async with uws_factory.session.begin(): - for i, job in enumerate(jobs): - hours = (2 - i) * 2 - creation = current_datetime() - timedelta(hours=hours) - stmt = ( - update(SQLJob) - .where(SQLJob.id == int(job.job_id)) - .values(creation_time=datetime_to_db(creation)) - ) - await uws_factory.session.execute(stmt) - job.creation_time = creation + jobs = mock_wobbly.jobs[test_service][test_username] + for i, job in enumerate(jobs.values()): + hours = (2 - i) * 2 + job.creation_time = current_datetime() - timedelta(hours=hours) # Retrieve the job list and check it. 
- r = await client.get("/test/jobs", headers={"X-Auth-Request-User": "user"}) + r = await client.get("/test/jobs") assert r.status_code == 200 assert r.headers["Content-Type"] == "application/xml" - creation_times = [isodatetime(j.creation_time) for j in jobs] + creation_times = [isodatetime(j.creation_time) for j in jobs.values()] creation_times.reverse() expected = FULL_JOB_LIST.strip().format(*creation_times) assert Jobs.from_xml(r.text) == Jobs.from_xml(expected) @@ -130,9 +121,7 @@ async def test_job_list(client: AsyncClient, uws_factory: UWSFactory) -> None: # Filter by recency. threshold = current_datetime() - timedelta(hours=1) r = await client.get( - "/test/jobs", - headers={"X-Auth-Request-User": "user"}, - params={"after": isodatetime(threshold)}, + "/test/jobs", params={"after": isodatetime(threshold)} ) assert r.status_code == 200 assert r.headers["Content-Type"] == "application/xml" @@ -142,42 +131,29 @@ async def test_job_list(client: AsyncClient, uws_factory: UWSFactory) -> None: # Check case-insensitivity. result = r.text r = await client.get( - "/test/jobs", - headers={"X-Auth-Request-User": "user"}, - params={"AFTER": isodatetime(threshold)}, + "/test/jobs", params={"AFTER": isodatetime(threshold)} ) assert r.text == result r = await client.get( - "/test/jobs", - headers={"X-Auth-Request-User": "user"}, - params={"aFTer": isodatetime(threshold)}, + "/test/jobs", params={"aFTer": isodatetime(threshold)} ) assert r.text == result # Filter by count. - r = await client.get( - "/test/jobs", - headers={"X-Auth-Request-User": "user"}, - params={"last": 1}, - ) + r = await client.get("/test/jobs", params={"last": 1}) assert r.status_code == 200 assert r.headers["Content-Type"] == "application/xml" expected = RECENT_JOB_LIST.strip().format(creation_times[0]) assert Jobs.from_xml(r.text) == Jobs.from_xml(expected) # Start the job. 
- r = await client.post( - "/test/jobs/2/phase", - headers={"X-Auth-Request-User": "user"}, - data={"PHASE": "RUN"}, - ) + r = await client.post("/test/jobs/2/phase", data={"PHASE": "RUN"}) assert r.status_code == 303 assert r.headers["Location"] == "https://example.com/test/jobs/2" # Filter by phase. r = await client.get( "/test/jobs", - headers={"X-Auth-Request-User": "user"}, params=[("PHASE", "EXECUTING"), ("PHASE", "QUEUED")], ) assert r.status_code == 200 diff --git a/safir/tests/uws/long_polling_test.py b/safir/tests/uws/long_polling_test.py index ec0bc147..4c631f54 100644 --- a/safir/tests/uws/long_polling_test.py +++ b/safir/tests/uws/long_polling_test.py @@ -9,13 +9,12 @@ from httpx import AsyncClient from vo_models.uws import JobSummary +from safir.arq.uws import WorkerResult from safir.datetime import current_datetime, isodatetime from safir.testing.uws import MockUWSJobRunner, assert_job_summary_equal -from safir.uws import UWSJobParameter from safir.uws._dependencies import UWSFactory -from safir.uws._models import UWSJobResult -from ..support.uws import SimpleXmlParameters +from ..support.uws import SimpleParameters, SimpleXmlParameters PENDING_JOB = """ 1 - user + test-user {} {} 600 @@ -46,7 +45,7 @@ xmlns:xlink="http://www.w3.org/1999/xlink" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"> 1 - user + test-user EXECUTING {} {} @@ -67,7 +66,7 @@ xmlns:xlink="http://www.w3.org/1999/xlink" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"> 1 - user + test-user COMPLETED {} {} @@ -87,22 +86,18 @@ @pytest.mark.asyncio async def test_poll( - client: AsyncClient, runner: MockUWSJobRunner, uws_factory: UWSFactory + client: AsyncClient, + test_token: str, + runner: MockUWSJobRunner, + uws_factory: UWSFactory, ) -> None: job_service = uws_factory.create_job_service() - job = await job_service.create( - "user", - params=[UWSJobParameter(parameter_id="name", value="Naomi")], - ) + job = await job_service.create(test_token, 
SimpleParameters(name="Naomi")) # Poll for changes for one second. Nothing will happen since nothing is # changing the mock arq queue. now = current_datetime() - r = await client.get( - "/test/jobs/1", - headers={"X-Auth-Request-User": "user"}, - params={"WAIT": "1"}, - ) + r = await client.get("/test/jobs/1", params={"WAIT": "1"}) assert (current_datetime() - now).total_seconds() >= 1 assert r.status_code == 200 assert_job_summary_equal( @@ -117,10 +112,7 @@ async def test_poll( # Start the job and worker. r = await client.post( - "/test/jobs/1/phase", - headers={"X-Auth-Request-User": "user"}, - data={"PHASE": "RUN"}, - follow_redirects=True, + "/test/jobs/1/phase", data={"PHASE": "RUN"}, follow_redirects=True ) assert r.status_code == 200 assert r.url == "https://example.com/test/jobs/1" @@ -137,12 +129,8 @@ async def test_poll( # Poll for a change from queued, which we should see after half a second. now = current_datetime() job, r = await asyncio.gather( - runner.mark_in_progress("user", "1", delay=0.5), - client.get( - "/test/jobs/1", - headers={"X-Auth-Request-User": "user"}, - params={"WAIT": "2", "phase": "QUEUED"}, - ), + runner.mark_in_progress(test_token, "1", delay=0.5), + client.get("/test/jobs/1", params={"WAIT": "2", "phase": "QUEUED"}), ) assert r.status_code == 200 assert job.start_time @@ -159,19 +147,15 @@ async def test_poll( # Now, wait again, in parallel with the job finishing. We should get a # reply after a second and a half when the job finishes. 
results = [ - UWSJobResult( + WorkerResult( result_id="cutout", url="s3://some-bucket/some/path", mime_type="application/fits", ) ] job, r = await asyncio.gather( - runner.mark_complete("user", "1", results, delay=1.5), - client.get( - "/test/jobs/1", - headers={"X-Auth-Request-User": "user"}, - params={"WAIT": "2", "phase": "EXECUTING"}, - ), + runner.mark_complete(test_token, "1", results, delay=1.5), + client.get("/test/jobs/1", params={"WAIT": "2", "phase": "EXECUTING"}), ) assert r.status_code == 200 assert job.start_time diff --git a/safir/tests/uws/post_params_test.py b/safir/tests/uws/post_params_test.py deleted file mode 100644 index ccf97d1c..00000000 --- a/safir/tests/uws/post_params_test.py +++ /dev/null @@ -1,30 +0,0 @@ -"""Tests for sync cutout requests.""" - -from __future__ import annotations - -import pytest -from httpx import AsyncClient - - -@pytest.mark.asyncio -async def test_post_params_multiple_params(client: AsyncClient) -> None: - """Test that the post dependency correctly handles multiple - occurences of the same parameter. 
- """ - params = [ - {"parameter_id": "id", "value": "image1"}, - {"parameter_id": "id", "value": "image2"}, - {"parameter_id": "pos", "value": "RANGE 10 20 30 40"}, - {"parameter_id": "pos", "value": "CIRCLE 10 20 5"}, - ] - - response = await client.post("/test/params", json=params) - assert response.status_code == 200 - assert response.json() == { - "params": [ - {"id": "id", "value": "image1"}, - {"id": "id", "value": "image2"}, - {"id": "pos", "value": "RANGE 10 20 30 40"}, - {"id": "pos", "value": "CIRCLE 10 20 5"}, - ] - } diff --git a/safir/tests/uws/workers_test.py b/safir/tests/uws/workers_test.py index 29b0e8ba..dad81801 100644 --- a/safir/tests/uws/workers_test.py +++ b/safir/tests/uws/workers_test.py @@ -27,10 +27,9 @@ ) from safir.datetime import current_datetime from safir.testing.slack import MockSlackWebhook -from safir.uws import UWSApplication, UWSConfig, UWSJobParameter +from safir.uws import JobResult, UWSApplication, UWSConfig from safir.uws._constants import UWS_DATABASE_TIMEOUT from safir.uws._dependencies import UWSFactory -from safir.uws._models import ErrorCode, UWSJobResult from safir.uws._storage import JobStore from ..support.uws import SimpleParameters @@ -50,7 +49,10 @@ def hello( @pytest.mark.asyncio async def test_build_worker( - uws_config: UWSConfig, logger: BoundLogger + uws_config: UWSConfig, + test_token: str, + test_username: str, + logger: BoundLogger, ) -> None: redis_settings = uws_config.arq_redis_settings worker_config = WorkerConfig( @@ -90,8 +92,8 @@ async def test_build_worker( params = SimpleParameters(name="Roger") info = WorkerJobInfo( job_id="42", - user="someuser", - token="some-token", + user=test_username, + token=test_token, timeout=timedelta(minutes=1), run_id="some-run-id", ) @@ -103,7 +105,7 @@ async def test_build_worker( JobMetadata( id=ANY, name="uws_job_started", - args=("42", ANY), + args=(test_token, "42", ANY), kwargs={}, enqueue_time=ANY, status=JobStatus.queued, @@ -112,7 +114,7 @@ async def 
test_build_worker( JobMetadata( id=ANY, name="uws_job_completed", - args=("42",), + args=(test_token, "42"), kwargs={}, enqueue_time=ANY, status=JobStatus.queued, @@ -126,7 +128,12 @@ async def test_build_worker( @pytest.mark.asyncio -async def test_timeout(uws_config: UWSConfig, logger: BoundLogger) -> None: +async def test_timeout( + uws_config: UWSConfig, + test_token: str, + test_username: str, + logger: BoundLogger, +) -> None: redis_settings = uws_config.arq_redis_settings worker_config = WorkerConfig( arq_mode=uws_config.arq_mode, @@ -154,8 +161,8 @@ async def test_timeout(uws_config: UWSConfig, logger: BoundLogger) -> None: params = SimpleParameters(name="Timeout") info = WorkerJobInfo( job_id="42", - user="someuser", - token="some-token", + user=test_username, + token=test_token, timeout=timedelta(seconds=1), run_id="some-run-id", ) @@ -165,7 +172,7 @@ async def test_timeout(uws_config: UWSConfig, logger: BoundLogger) -> None: JobMetadata( id=ANY, name="uws_job_started", - args=("42", ANY), + args=(test_token, "42", ANY), kwargs={}, enqueue_time=ANY, status=JobStatus.queued, @@ -174,7 +181,7 @@ async def test_timeout(uws_config: UWSConfig, logger: BoundLogger) -> None: JobMetadata( id=ANY, name="uws_job_completed", - args=("42",), + args=(test_token, "42"), kwargs={}, enqueue_time=ANY, status=JobStatus.queued, @@ -201,18 +208,18 @@ async def test_timeout(uws_config: UWSConfig, logger: BoundLogger) -> None: async def test_build_uws_worker( arq_queue: MockArqQueue, uws_config: UWSConfig, + test_token: str, + test_username: str, uws_factory: UWSFactory, mock_slack: MockSlackWebhook, logger: BoundLogger, ) -> None: uws = UWSApplication(uws_config) job_service = uws_factory.create_job_service() - job = await job_service.create( - "user", params=[UWSJobParameter(parameter_id="name", value="Ahmed")] - ) - results = [UWSJobResult(result_id="greeting", url="https://example.com")] - await job_service.start("user", job.job_id, "some-token") - job = await 
job_service.get("user", job.job_id) + job = await job_service.create(test_token, SimpleParameters(name="Ahmed")) + results = [WorkerResult(result_id="greeting", url="https://example.com")] + await job_service.start(test_token, test_username, job.id) + job = await job_service.get(test_token, job.id) assert job.start_time is None assert job.phase == ExecutionPhase.QUEUED @@ -223,12 +230,6 @@ async def test_build_uws_worker( assert callable(job_started) job_completed = settings.functions[1] assert callable(job_completed) - assert settings.cron_jobs - assert len(settings.cron_jobs) == 1 - expire_cron = settings.cron_jobs[0] - assert expire_cron.unique - expire_jobs = expire_cron.coroutine - assert callable(expire_jobs) assert settings.redis_settings == uws_config.arq_redis_settings assert not settings.allow_abort_jobs assert settings.job_completion_wait == UWS_DATABASE_TIMEOUT @@ -250,39 +251,25 @@ async def test_build_uws_worker( now = current_datetime() assert job.message_id await arq_queue.set_in_progress(job.message_id) - await job_started(ctx, job.job_id, now) - job = await job_service.get("user", job.job_id) + await job_started(ctx, test_token, job.id, now) + job = await job_service.get(test_token, job.id) assert job.phase == ExecutionPhase.EXECUTING assert job.start_time == now # Test finishing a job. assert job.message_id await asyncio.gather( - job_completed(ctx, job.job_id), + job_completed(ctx, test_token, job.id), arq_queue.set_complete(job.message_id, result=results), ) - job = await job_service.get("user", job.job_id) + job = await job_service.get(test_token, job.id) assert job.phase == ExecutionPhase.COMPLETED assert job.end_time assert job.end_time.microsecond == 0 assert now <= job.end_time <= current_datetime() - assert job.results == results + assert job.results == [JobResult.from_worker_result(r) for r in results] assert mock_slack.messages == [] - # Expiring jobs should do nothing since the destruction time of our one - # job has not passed. 
- jobs = await job_service.list_jobs("user", "https://example.com") - await expire_jobs(ctx) - assert await job_service.list_jobs("user", "https://example.com") == jobs - - # Change the destruction date of the job and then it should be expired. - past = current_datetime() - timedelta(minutes=5) - expires = await job_service.update_destruction("user", job.job_id, past) - assert expires == past - await expire_jobs(ctx) - jobs = await job_service.list_jobs("user", "https://example.com") - assert not jobs.jobref - def nonnegative(value: int) -> None: if value < 0: raise ValueError("Value not nonnegative") @@ -296,35 +283,34 @@ def make_exception() -> None: ) from e # Test starting and erroring a job with a TaskError. - job = await job_service.create( - "user", params=[UWSJobParameter(parameter_id="name", value="Ahmed")] - ) - await job_service.start("user", job.job_id, "some-token") - job = await job_service.get("user", job.job_id) + job = await job_service.create(test_token, SimpleParameters(name="Ahmed")) + await job_service.start(test_token, test_username, job.id) + job = await job_service.get(test_token, job.id) assert job.message_id await arq_queue.set_in_progress(job.message_id) - await job_started(ctx, job.job_id, now) + await job_started(ctx, test_token, job.id, now) try: make_exception() except WorkerFatalError as e: error = e await asyncio.gather( - job_completed(ctx, job.job_id), + job_completed(ctx, test_token, job.id), arq_queue.set_complete(job.message_id, result=error, success=False), ) - job = await job_service.get("user", job.job_id) + job = await job_service.get(test_token, job.id) assert job.phase == ExecutionPhase.ERROR assert job.end_time assert job.end_time.microsecond == 0 assert now <= job.end_time <= current_datetime() - assert job.error - assert job.error.error_type == ErrorType.FATAL - assert job.error.error_code == ErrorCode.ERROR - assert job.error.message == "Something" - assert job.error.detail - assert "went wrong" in job.error.detail + 
assert job.errors + assert len(job.errors) == 1 + assert job.errors[0].type == ErrorType.FATAL + assert job.errors[0].code == "Error" + assert job.errors[0].message == "Something" + assert job.errors[0].detail + assert "went wrong" in job.errors[0].detail assert error.traceback - assert error.traceback in job.error.detail + assert error.traceback in job.errors[0].detail assert mock_slack.messages == [ { "blocks": [ @@ -351,7 +337,7 @@ def make_exception() -> None: {"text": ANY, "type": "mrkdwn", "verbatim": True}, {"text": ANY, "type": "mrkdwn", "verbatim": True}, { - "text": "*User*\nuser", + "text": f"*User*\n{test_username}", "type": "mrkdwn", "verbatim": True, }, @@ -396,26 +382,25 @@ def make_exception() -> None: # Test starting and erroring a job with an unknown exception. mock_slack.messages = [] - job = await job_service.create( - "user", params=[UWSJobParameter(parameter_id="name", value="Ahmed")] - ) - await job_service.start("user", job.job_id, "some-token") - job = await job_service.get("user", job.job_id) + job = await job_service.create(test_token, SimpleParameters(name="Ahmed")) + await job_service.start(test_token, test_username, job.id) + job = await job_service.get(test_token, job.id) assert job.message_id await arq_queue.set_in_progress(job.message_id) - await job_started(ctx, job.job_id, now) + await job_started(ctx, test_token, job.id, now) exc = ValueError("some error") await asyncio.gather( - job_completed(ctx, job.job_id), + job_completed(ctx, test_token, job.id), arq_queue.set_complete(job.message_id, result=exc, success=False), ) - job = await job_service.get("user", job.job_id) + job = await job_service.get(test_token, job.id) assert job.phase == ExecutionPhase.ERROR - assert job.error - assert job.error.error_type == ErrorType.FATAL - assert job.error.error_code == ErrorCode.ERROR - assert job.error.message == "Unknown error executing task" - assert job.error.detail == "ValueError: some error" + assert job.errors + assert len(job.errors) 
== 1 + assert job.errors[0].type == ErrorType.FATAL + assert job.errors[0].code == "Error" + assert job.errors[0].message == "Unknown error executing task" + assert job.errors[0].detail == "ValueError: some error" assert mock_slack.messages == [ { "blocks": [