diff --git a/.github/workflows/test-instance.yml b/.github/workflows/test-instance.yml index 7e78929fcf..f1ac869d04 100644 --- a/.github/workflows/test-instance.yml +++ b/.github/workflows/test-instance.yml @@ -3,7 +3,7 @@ name: "Spawn Test Instance" on: [push, pull_request] jobs: - pytest: + test_instance: runs-on: ubuntu-latest strategy: diff --git a/superdesk/celery_app/serializer.py b/superdesk/celery_app/serializer.py index 268ea4d7e7..507018eb9e 100644 --- a/superdesk/celery_app/serializer.py +++ b/superdesk/celery_app/serializer.py @@ -7,7 +7,7 @@ from kombu.serialization import register from superdesk.core import json -from superdesk.core.web.types import WSGIApp +from superdesk.core.types import WSGIApp CELERY_SERIALIZER_NAME = "context-aware/json" diff --git a/superdesk/core/__init__.py b/superdesk/core/__init__.py index 9b9b378b34..9c7321d7af 100644 --- a/superdesk/core/__init__.py +++ b/superdesk/core/__init__.py @@ -9,7 +9,7 @@ # at https://www.sourcefabric.org/superdesk/license from quart import json -from .app import get_app_config, get_current_app, get_current_async_app +from .app import get_app_config, get_current_app, get_current_async_app, get_current_auth __all__ = [ @@ -17,4 +17,5 @@ "get_current_async_app", "json", "get_app_config", + "get_current_auth", ] diff --git a/superdesk/core/app.py b/superdesk/core/app.py index 8ac444fb2f..6d7d617422 100644 --- a/superdesk/core/app.py +++ b/superdesk/core/app.py @@ -11,7 +11,8 @@ from typing import Dict, List, Optional, Any, cast import importlib -from .web import WSGIApp +from superdesk.core.types import WSGIApp +from .auth.user_auth import UserAuthProtocol def get_app_config(key: str, default: Optional[Any] = None) -> Optional[Any]: @@ -44,6 +45,8 @@ class SuperdeskAsyncApp: resources: "Resources" + auth: UserAuthProtocol + def __init__(self, wsgi: WSGIApp): self._running = False self._imported_modules = {} @@ -52,6 +55,7 @@ def __init__(self, wsgi: WSGIApp): self.mongo = MongoResources(self) 
def load_auth_module(self) -> UserAuthProtocol:
    """Instantiate the authentication backend configured for this app.

    The ``ASYNC_AUTH_CLASS`` config value is read from the WSGI app config and
    must have the form ``"<module path>:<attribute name>"``. It defaults to
    ``superdesk.core.auth.token_auth:TokenAuthorization``.

    Errors from importing the module or looking up the attribute
    (``ImportError`` / ``AttributeError``) are not intercepted here.

    :return: A new instance of the configured auth class
    :raises RuntimeError: If the config value is not of the ``module:attribute``
        form, or the attribute is not a subclass of ``UserAuthProtocol``
    """
    # Keep the key in one place so that error messages always name the config
    # entry that is actually read
    config_key = "ASYNC_AUTH_CLASS"
    auth_module_config = cast(
        str, self.wsgi.config.get(config_key, "superdesk.core.auth.token_auth:TokenAuthorization")
    )

    try:
        module_path, module_attribute = auth_module_config.split(":")
    except ValueError as error:
        # Zero or more than one ":" in the configured value
        raise RuntimeError(f"Invalid config {config_key}={auth_module_config}: {error}") from error

    imported_module = importlib.import_module(module_path)
    auth_class = getattr(imported_module, module_attribute)

    # ``issubclass`` raises TypeError when given something that is not a class,
    # so rule that out first and report it as a config error instead
    if not (isinstance(auth_class, type) and issubclass(auth_class, UserAuthProtocol)):
        raise RuntimeError(f"Invalid config {config_key}={auth_module_config}, invalid auth type {auth_class}")

    return auth_class()
class TokenAuthorization(UserAuthProtocol):
    """Authenticate requests with a token from the ``auth`` resource.

    The token is taken from the ``Authorization`` header or, when that header
    is absent, from the ``session_token`` entry of the session storage.
    """

    @staticmethod
    def _get_token_from_header(header: str) -> str:
        """Extract the token from an ``Authorization`` header value.

        A value starting with ``token`` or ``bearer`` (case-insensitive) has
        that scheme dropped; any other value is used verbatim.

        :param header: The raw header value
        :return: The token, or an empty string if a scheme was given without a token
        """
        value = header.strip()
        if not value.lower().startswith(("token", "bearer")):
            return value

        # Split on any run of whitespace, so that ``Bearer  abc`` (two spaces)
        # or a tab separator still yields ``abc`` rather than an empty token
        parts = value.split()
        return parts[1] if len(parts) > 1 else ""

    async def _unauthorized(self, request: Request) -> Exception:
        """Stop the current session and return the unauthorized error to raise."""
        await self.stop_session(request)
        return SuperdeskApiError.unauthorizedError()

    async def authenticate(self, request: Request):
        """Resolve the request's token to a user, then start or continue the session.

        :param request: The request being authenticated
        :raises SuperdeskApiError: Unauthorized error if no token is supplied, the
            token is not found in the ``auth`` resource, or its user is not found
        """
        header = request.get_header("Authorization")
        if header:
            new_session = True
            token = self._get_token_from_header(header)
        else:
            new_session = False
            token = request.storage.session.get("session_token")
            # ``start_session`` stores the whole auth document in the session,
            # while the lookup below needs the token string itself
            if isinstance(token, dict):
                token = token.get("token")

        if not token:
            raise await self._unauthorized(request)

        # Check provided token is valid
        auth_token = get_resource_service("auth").find_one(token=token, req=None)
        if not auth_token:
            raise await self._unauthorized(request)

        user_id = str(auth_token["user"])
        user = get_resource_service("users").find_one(req=None, _id=user_id)
        if not user:
            raise await self._unauthorized(request)

        if new_session:
            await self.start_session(request, user, auth_token=auth_token)
        else:
            await self.continue_session(request, user)

    async def start_session(self, request: Request, user: dict[str, Any], **kwargs) -> None:
        """Store the auth document in the session and hand over to the parent class.

        :param request: The request starting the session
        :param user: The authenticated user
        :param kwargs: Must contain ``auth_token``; the remaining entries are
            forwarded to the parent implementation
        :raises SuperdeskApiError: Unauthorized error if ``auth_token`` is missing or empty
        """
        # ``authenticate`` passes the document found in the ``auth`` resource
        # (``continue_session`` indexes into it), not a plain token string
        auth_token: dict[str, Any] | None = kwargs.pop("auth_token", None)
        if not auth_token:
            raise await self._unauthorized(request)

        request.storage.session.set("session_token", auth_token)
        await super().start_session(request, user, **kwargs)

    async def continue_session(self, request: Request, user: dict[str, Any], **kwargs) -> None:
        """Populate the request storage from the session and record user activity.

        Sets ``user``, ``role``, ``auth`` and ``auth_value`` on the request storage.
        For ``POST``/``PUT``/``PATCH`` requests, and ``GET`` requests without a truthy
        ``auto`` URL argument, the session is updated when its last update is older
        than ``SESSION_UPDATE_SECONDS`` (default 30), and the user's
        ``last_activity_at`` is updated when the session was updated or the user
        has no ``last_activity_at`` yet.

        :param request: The request continuing the session
        :param user: The authenticated user
        :param kwargs: Forwarded to the parent implementation
        :raises SuperdeskApiError: Unauthorized error if the session holds no ``session_token``
        """
        auth_token = request.storage.session.get("session_token")
        if not auth_token:
            raise await self._unauthorized(request)

        user_service = get_resource_service("users")
        request_storage = request.storage.request
        request_storage.set("user", user)
        request_storage.set("role", user_service.get_role(user))
        request_storage.set("auth", auth_token)
        request_storage.set("auth_value", auth_token["user"])

        if request.method in ("POST", "PUT", "PATCH") or (request.method == "GET" and not request.get_url_arg("auto")):
            now = utcnow()
            auth_updated = False
            session_update_seconds = cast(int, get_app_config("SESSION_UPDATE_SECONDS", 30))
            if auth_token[LAST_UPDATED] + timedelta(seconds=session_update_seconds) < now:
                get_resource_service("auth").update_session({LAST_UPDATED: now})
                auth_updated = True

            # Look at the user document itself: ``last_activity_at`` is written to
            # the user below and is never placed on the request storage, so checking
            # the storage would trigger this write on every request
            if auth_updated or not user.get("last_activity_at"):
                user_service.system_update(user[ID_FIELD], {"last_activity_at": now, "_updated": now}, user)

        await super().continue_session(request, user, **kwargs)

    def get_current_user(self, request: Request) -> dict[str, Any] | None:
        """Return the ``user`` entry of the request storage, as set by ``continue_session``."""
        return request.storage.request.get("user")
class UserAuthProtocol:
    """Base class for user authentication / authorization backends.

    Subclasses are expected to override ``authenticate`` and
    ``get_current_user``; the implementations here reject every request and
    raise ``NotImplementedError`` respectively.
    """

    async def authenticate(self, request: Request) -> Any | None:
        """Authenticate the request (the base implementation always refuses).

        :raises SuperdeskApiError: Unauthorized error, unconditionally
        """
        raise SuperdeskApiError.unauthorizedError()

    async def authorize(self, request: Request) -> Any | None:
        """Run the auth rules that apply to the request's endpoint.

        The two default rules (login required, endpoint intrinsic auth) run
        first, followed by the rules configured on the endpoint. A dict of
        rules is keyed by HTTP method.

        :param request: The request to authorize
        :return: The first non-``None`` value returned by a rule, otherwise ``None``
        """
        configured_rules = request.endpoint.get_auth_rules()

        # ``False`` marks a public facing endpoint:
        # Authentication & Authorization is disabled for it
        if configured_rules is False:
            return None

        if isinstance(configured_rules, dict):
            configured_rules = cast(list[AuthRule], configured_rules.get(request.method) or [])

        from .rules import login_required_auth_rule, endpoint_intrinsic_auth_rule

        default_rules: list[AuthRule] = [login_required_auth_rule, endpoint_intrinsic_auth_rule]
        rules_to_run = default_rules + (configured_rules or [])

        for auth_rule in rules_to_run:
            outcome = await auth_rule(request)
            if outcome is not None:
                return outcome

        return None

    async def start_session(self, request: Request, user: Any, **kwargs) -> None:
        """Start a session; by default this is the same as continuing one."""
        await self.continue_session(request, user, **kwargs)

    async def continue_session(self, request: Request, user: Any, **kwargs) -> None:
        """Attach the given user to the request."""
        request.user = user

    async def stop_session(self, request: Request) -> None:
        """Stop the session (nothing to do in the base implementation)."""

    def get_current_user(self, request: Request) -> Any | None:
        """Return the user of the current request; must be provided by subclasses.

        :raises NotImplementedError: Always, in this base class
        """
        raise NotImplementedError()
TypeVar("ResourceModelType", bound="ResourceModel") + +ResourceModelType = TypeVar("ResourceModelType", bound=ResourceModel) class ResourceCursorAsync(Generic[ResourceModelType]): @@ -63,7 +65,9 @@ def get_model_instance(self, data: Dict[str, Any]): return self.data_class.from_dict(data) -class ElasticsearchResourceCursorAsync(ResourceCursorAsync): +class ElasticsearchResourceCursorAsync(ResourceCursorAsync[ResourceModelType], Generic[ResourceModelType]): + _index: int + hits: dict[str, Any] no_hits = {"hits": {"total": 0, "hits": []}} def __init__(self, data_class: Type[ResourceModelType], hits=None): @@ -113,7 +117,7 @@ def extra(self, response: Dict[str, Any]): response["_aggregations"] = self.hits["aggregations"] -class MongoResourceCursorAsync(ResourceCursorAsync): +class MongoResourceCursorAsync(ResourceCursorAsync[ResourceModelType], Generic[ResourceModelType]): def __init__( self, data_class: Type[ResourceModelType], @@ -137,6 +141,3 @@ async def next_raw(self) -> Optional[Dict[str, Any]]: async def count(self): return await self.collection.count_documents(self.lookup) - - -from .model import ResourceModel # noqa: E402 diff --git a/superdesk/core/resources/model.py b/superdesk/core/resources/model.py index 8236414f93..9c8ba5b6df 100644 --- a/superdesk/core/resources/model.py +++ b/superdesk/core/resources/model.py @@ -188,7 +188,15 @@ async def _run_async_validators_from_model_class( if field_name_stack is None: field_name_stack = [] - annotations = get_annotations(model_instance.__class__) + model_class = model_instance.__class__ + + try: + annotations = {} + for base_class in reversed(model_class.__mro__): + if base_class != ResourceModel and issubclass(base_class, ResourceModel): + annotations.update(get_annotations(base_class)) + except (TypeError, AttributeError): + annotations = get_annotations(model_class) for field_name, annotation in annotations.items(): value = getattr(model_instance, field_name) diff --git 
a/superdesk/core/resources/resource_rest_endpoints.py b/superdesk/core/resources/resource_rest_endpoints.py index ec1c93fb58..0fd7a30aad 100644 --- a/superdesk/core/resources/resource_rest_endpoints.py +++ b/superdesk/core/resources/resource_rest_endpoints.py @@ -20,12 +20,20 @@ from superdesk.core import json from superdesk.core.app import get_current_async_app -from superdesk.core.types import SearchRequest, SearchArgs, VersionParam +from superdesk.core.types import ( + SearchRequest, + SearchArgs, + VersionParam, + AuthConfig, + HTTP_METHOD, + Request, + Response, + RestGetResponse, +) from superdesk.errors import SuperdeskApiError from superdesk.resource_fields import STATUS, STATUS_OK, ITEMS -from ..web.types import HTTP_METHOD, Request, Response, RestGetResponse -from ..web.rest_endpoints import RestEndpoints, ItemRequestViewArgs +from superdesk.core.web import RestEndpoints, ItemRequestViewArgs from .model import ResourceConfig, ResourceModel from .validators import convert_pydantic_validation_error_for_response @@ -77,6 +85,8 @@ class RestEndpointConfig: #: This will prepend this resources URL with the URL of the parent resource item parent_links: list[RestParentLink] | None = None + auth: AuthConfig = None + def get_id_url_type(data_class: type[ResourceModel]) -> str: """Get the URL param type for the ID field for route registration""" @@ -116,6 +126,7 @@ def __init__( resource_methods=endpoint_config.resource_methods, item_methods=endpoint_config.item_methods, id_param_type=endpoint_config.id_param_type or get_id_url_type(resource_config.data_class), + auth=endpoint_config.auth, ) def get_resource_url(self) -> str: diff --git a/superdesk/core/resources/service.py b/superdesk/core/resources/service.py index 3da6919bf1..cb3754ed76 100644 --- a/superdesk/core/resources/service.py +++ b/superdesk/core/resources/service.py @@ -22,6 +22,7 @@ Union, cast, overload, + Type, ) import logging import ast @@ -85,10 +86,16 @@ def id_uses_objectid(self) -> bool: def 
mongo(self): """Return instance of MongoCollection for this resource""" + return self.app.mongo.get_collection(self.resource_name) + + @property + def mongo_async(self): + """Return instance of async MongoCollection for this resource""" + return self.app.mongo.get_collection_async(self.resource_name) @property - def mongo_versioned(self): + def mongo_versioned_async(self): return self.app.mongo.get_collection_async(self.resource_name, True) @property @@ -128,7 +135,7 @@ async def find_one_raw(self, use_mongo: bool = False, version: int | None = None pass if use_mongo or item is None: - item = await self.mongo.find_one(lookup) + item = await self.mongo_async.find_one(lookup) if item is None: return None @@ -175,7 +182,7 @@ async def find_by_id_raw( try: item = await self.elastic.find_by_id(item_id) except KeyError: - item = await self.mongo.find_one({"_id": item_id}) + item = await self.mongo_async.find_one({"_id": item_id}) if item is None: return None @@ -195,12 +202,14 @@ async def search(self, lookup: Dict[str, Any], use_mongo=False) -> ResourceCurso try: if not use_mongo: response = await self.elastic.search(lookup) - return ElasticsearchResourceCursorAsync(self.config.data_class, response) + return ElasticsearchResourceCursorAsync(cast(Type[ResourceModelType], self.config.data_class), response) except KeyError: pass - response = self.mongo.find(lookup) - return MongoResourceCursorAsync(self.config.data_class, self.mongo, response, lookup) + response = self.mongo_async.find(lookup) + return MongoResourceCursorAsync( + cast(Type[ResourceModelType], self.config.data_class), self.mongo_async, response, lookup + ) async def on_create(self, docs: List[ResourceModelType]) -> None: """Hook to run before creating new resource(s) @@ -283,7 +292,7 @@ async def create(self, _docs: Sequence[ResourceModelType | dict[str, Any]]) -> L context={"use_objectid": True} if not self.config.query_objectid_as_string else {}, ) doc.etag = doc_dict["_etag"] = 
self.generate_etag(doc_dict, self.config.etag_ignore_fields) - response = await self.mongo.insert_one(doc_dict) + response = await self.mongo_async.insert_one(doc_dict) ids.append(response.inserted_id) try: await self.elastic.insert([doc_dict]) @@ -291,7 +300,7 @@ async def create(self, _docs: Sequence[ResourceModelType | dict[str, Any]]) -> L pass if self.config.versioning: - await self.mongo_versioned.insert_one(self._get_versioned_document(doc_dict)) + await self.mongo_versioned_async.insert_one(self._get_versioned_document(doc_dict)) await self.on_created(docs) return ids @@ -354,14 +363,14 @@ async def update(self, item_id: Union[str, ObjectId], updates: Dict[str, Any], e if model_has_versions(original): updates.pop("_latest_version", None) updates_dict.pop("_latest_version", None) - response = await self.mongo.update_one({"_id": item_id}, {"$set": updates_dict}) + response = await self.mongo_async.update_one({"_id": item_id}, {"$set": updates_dict}) try: await self.elastic.update(item_id, updates_dict) except KeyError: pass if self.config.versioning: - await self.mongo_versioned.insert_one(self._get_versioned_document(validated_updates)) + await self.mongo_versioned_async.insert_one(self._get_versioned_document(validated_updates)) await self.on_updated(updates, original) @@ -391,7 +400,7 @@ async def delete(self, doc: ResourceModelType, etag: str | None = None): await self.on_delete(doc) self.validate_etag(doc, etag) - await self.mongo.delete_one({"_id": doc.id}) + await self.mongo_async.delete_one({"_id": doc.id}) try: await self.elastic.remove(doc.id) except KeyError: @@ -405,14 +414,14 @@ async def delete_many(self, lookup: Dict[str, Any]) -> List[str]: :return: List of IDs for the deleted resources """ - docs_to_delete = self.mongo.find(lookup).sort("_id", 1) + docs_to_delete = self.mongo_async.find(lookup).sort("_id", 1) ids: List[str] = [] async for data in docs_to_delete: doc = self.get_model_instance_from_dict(data) await self.on_delete(doc) 
ids.append(str(doc.id)) - await self.mongo.delete_one({"_id": doc.id}) + await self.mongo_async.delete_one({"_id": doc.id}) try: await self.elastic.remove(doc.id) @@ -441,7 +450,7 @@ async def get_all(self) -> AsyncIterable[ResourceModelType]: yield doc def get_all_raw(self) -> AsyncIOMotorCursor: - return self.mongo.find({}).sort("_id") + return self.mongo_async.find({}).sort("_id") async def get_all_batch(self, size=500, max_iterations=10000, lookup=None) -> AsyncIterable[ResourceModelType]: """Helper function to get all items from this resource, in batches @@ -460,7 +469,7 @@ async def get_all_batch(self, size=500, max_iterations=10000, lookup=None) -> As if last_id is not None: _lookup.update({"_id": {"$gt": last_id}}) - cursor = self.mongo.find(_lookup).sort("_id").limit(size) + cursor = self.mongo_async.find(_lookup).sort("_id").limit(size) last_doc = None async for data in cursor: last_doc = data @@ -530,13 +539,17 @@ async def find( try: if not use_mongo: cursor, count = await self.elastic.find(search_request) - return ElasticsearchResourceCursorAsync(self.config.data_class, cursor.hits) + return ElasticsearchResourceCursorAsync( + cast(Type[ResourceModelType], self.config.data_class), cursor.hits + ) except KeyError: pass return await self._mongo_find(search_request) - async def _mongo_find(self, req: SearchRequest, versioned: bool = False) -> MongoResourceCursorAsync: + async def _mongo_find( + self, req: SearchRequest, versioned: bool = False + ) -> MongoResourceCursorAsync[ResourceModelType]: kwargs: Dict[str, Any] = {} if req.max_results: @@ -558,10 +571,13 @@ async def _mongo_find(self, req: SearchRequest, versioned: bool = False) -> Mong projection_fields if projection_include else {field: False for field in projection_fields} ) - cursor = self.mongo.find(**kwargs) if not versioned else self.mongo_versioned.find(**kwargs) + cursor = self.mongo_async.find(**kwargs) if not versioned else self.mongo_versioned_async.find(**kwargs) return 
async def system_update(self, item_id: ObjectId | str, updates: dict[str, Any]) -> None:
    """Apply ``updates`` to a stored item directly in MongoDB and Elasticsearch.

    The updates are written with ``$set`` as given; this method does not call
    the ``on_update``/``on_updated`` hooks or touch the etag.

    :param item_id: ID of the item to update
    :param updates: Field values to set on the item
    """
    lookup = {"_id": item_id}
    changes = {"$set": updates}
    await self.mongo_async.update_one(lookup, changes)

    try:
        await self.elastic.update(item_id, updates)
    except KeyError:
        # Same best-effort handling as the other Elasticsearch calls in this
        # service; presumably raised when the resource has no Elasticsearch
        # config (TODO confirm)
        return
AfterValidator: - """Validates that the value is a valid email address""" +def validate_email(error_string: str | None = None) -> AfterValidator: + """Validates that the value is a valid email address + + :param error_string: An optional custom error string if validation fails + """ def _validate_email(value: EmailValueType) -> EmailValueType: if value is None: @@ -41,7 +44,7 @@ def _validate_email(value: EmailValueType) -> EmailValueType: # given that admins are usually create users, not users by themself, # probably just check for @ is enough # https://davidcel.is/posts/stop-validating-email-addresses-with-regex/ - raise PydanticCustomError("email", gettext("Invalid email address")) + raise PydanticCustomError("email", str(error_string) if error_string else gettext("Invalid email address")) return value return AfterValidator(_validate_email) @@ -50,11 +53,14 @@ def _validate_email(value: EmailValueType) -> EmailValueType: MinMaxValueType = str | int | float | list[str] | list[int] | list[float] | None -def validate_minlength(min_length: int, validate_list_elements: bool = False) -> AfterValidator: +def validate_minlength( + min_length: int, validate_list_elements: bool = False, error_string: str | None = None +) -> AfterValidator: """Validates that the value has a minimum length :param min_length: The minimum length of the value :param validate_list_elements: Whether to validate the elements in the list or the list length + :param error_string: An optional custom error string if validation fails """ def _validate_minlength(value: MinMaxValueType) -> MinMaxValueType: @@ -63,20 +69,23 @@ def _validate_minlength(value: MinMaxValueType) -> MinMaxValueType: _validate_minlength(val) elif isinstance(value, (type(""), list)): if len(value) < min_length: - raise PydanticCustomError("minlength", gettext("Not enough")) + raise PydanticCustomError("minlength", str(error_string) if error_string else gettext("Not enough")) elif isinstance(value, (int, float)): if value < 
min_length: - raise PydanticCustomError("min_length", gettext("Too short")) + raise PydanticCustomError("min_length", str(error_string) if error_string else gettext("Too short")) return value return AfterValidator(_validate_minlength) -def validate_maxlength(max_length: int, validate_list_elements: bool = False) -> AfterValidator: +def validate_maxlength( + max_length: int, validate_list_elements: bool = False, error_string: str | None = None +) -> AfterValidator: """Validates that the value has a maximum length (strings or arrays) :param max_length: The maximum length of the value :param validate_list_elements: Whether to validate the elements in the list or the list length + :param error_string: An optional custom error string if validation fails """ def _validate_maxlength(value: MinMaxValueType) -> MinMaxValueType: @@ -85,10 +94,10 @@ def _validate_maxlength(value: MinMaxValueType) -> MinMaxValueType: _validate_maxlength(val) elif isinstance(value, (type(""), list)): if len(value) > max_length: - raise PydanticCustomError("maxlength", gettext("Too many")) + raise PydanticCustomError("maxlength", str(error_string) if error_string else gettext("Too many")) elif isinstance(value, (int, float)): if value > max_length: - raise PydanticCustomError("maxlength", gettext("Too short")) + raise PydanticCustomError("maxlength", str(error_string) if error_string else gettext("Too short")) return value return AfterValidator(_validate_maxlength) @@ -105,13 +114,14 @@ def __init__(self, func: Callable[["ResourceModel", Any], Awaitable[None]]): def validate_data_relation_async( - resource_name: str, external_field: str = "_id", convert_to_objectid: bool = False + resource_name: str, external_field: str = "_id", convert_to_objectid: bool = False, error_string: str | None = None ) -> AsyncValidator: """Validate the ID on the resource points to an existing resource :param resource_name: The name of the resource type the ID points to :param external_field: The field used to find 
the resource :param convert_to_objectid: If True, will convert the ID to an ObjectId instance + :param error_string: An optional custom error string if validation fails """ async def validate_resource_exists(item: ResourceModel, item_id: DataRelationValueType) -> None: @@ -133,7 +143,9 @@ async def validate_resource_exists(item: ResourceModel, item_id: DataRelationVal if not await collection.find_one({external_field: item_id}): raise PydanticCustomError( "data_relation", - gettext("Resource '{resource_name}' with ID '{item_id}' does not exist"), + str(error_string) + if error_string + else gettext("Resource '{resource_name}' with ID '{item_id}' does not exist"), dict( resource_name=resource_name, item_id=item_id, @@ -162,11 +174,12 @@ async def validate_resource_exists(item: ResourceModel, item_id: DataRelationVal UniqueValueType = str | list[str] | None -def validate_unique_value_async(resource_name: str, field_name: str) -> AsyncValidator: +def validate_unique_value_async(resource_name: str, field_name: str, error_string: str | None = None) -> AsyncValidator: """Validate that the field is unique in the resource (case-sensitive) :param resource_name: The name of the resource where the field must be unique :param field_name: The name of the field where the field must be unique + :param error_string: An optional custom error string if validation fails """ async def validate_unique_value_in_resource(item: ResourceModel, name: UniqueValueType) -> None: @@ -181,16 +194,19 @@ async def validate_unique_value_in_resource(item: ResourceModel, name: UniqueVal query = {"_id": {"$ne": item.id}, field_name: {"$in": name} if isinstance(name, list) else name} if await collection.find_one(query): - raise PydanticCustomError("unique", gettext("Value must be unique")) + raise PydanticCustomError("unique", str(error_string) if error_string else gettext("Value must be unique")) return AsyncValidator(validate_unique_value_in_resource) -def validate_iunique_value_async(resource_name: 
str, field_name: str) -> AsyncValidator: +def validate_iunique_value_async( + resource_name: str, field_name: str, error_string: str | None = None +) -> AsyncValidator: """Validate that the field is unique in the resource (case-insensitive) :param resource_name: The name of the resource where the field must be unique :param field_name: The name of the field where the field must be unique + :param error_string: An optional custom error string if validation fails """ async def validate_iunique_value_in_resource(item: ResourceModel, name: UniqueValueType) -> None: @@ -213,7 +229,7 @@ async def validate_iunique_value_in_resource(item: ResourceModel, name: UniqueVa } if await collection.find_one(query): - raise PydanticCustomError("unique", gettext("Value must be unique")) + raise PydanticCustomError("unique", str(error_string) if error_string else gettext("Value must be unique")) return AsyncValidator(validate_iunique_value_in_resource) diff --git a/superdesk/core/types.py b/superdesk/core/types.py index 8b8433dd76..2a25881ba4 100644 --- a/superdesk/core/types.py +++ b/superdesk/core/types.py @@ -8,12 +8,30 @@ # AUTHORS and LICENSE files distributed with this source code, or # at https://www.sourcefabric.org/superdesk/license -from typing import Dict, Any, Optional, List, Union, Literal +from typing import ( + Dict, + Any, + Optional, + List, + Literal, + Sequence, + Union, + Callable, + Awaitable, + TypeVar, + Protocol, + Mapping, + NoReturn, +) from typing_extensions import TypedDict +from dataclasses import dataclass from pydantic import BaseModel, ConfigDict, NonNegativeInt, field_validator -from . 
import json +DefaultNoValue = object() + + +HTTP_METHOD = Literal["GET", "POST", "PATCH", "PUT", "DELETE", "HEAD", "OPTIONS"] #: The data type for projections, either a list of field names, or a dictionary containing @@ -95,8 +113,384 @@ class SearchRequest(BaseModel): @field_validator("projection", mode="before") def parse_projection(cls, value: ProjectedFieldArg | str | None) -> ProjectedFieldArg | None: + from superdesk.core import json + if not value: return None elif isinstance(value, str): return json.loads(value) return value + + +@dataclass +class Response: + """Dataclass for endpoints to return response from a request""" + + #: The body of the response (Flask will determine data type for us) + body: Any + + #: HTTP Status Code of the response + status_code: int = 200 + + #: Any additional headers to be added + headers: Sequence = () + + +PydanticModelType = TypeVar("PydanticModelType", bound=BaseModel) + + +#: Function for use with an Endpoint registration and request processing +#: +#: Supported endpoint signatures:: +#: +#: # Response only +#: async def test() -> Response: +#: +#: # Request Only +#: async def test1(request: Request) -> Response +#: +#: # Args and Request +#: async def test2( +#: args: Pydantic.BaseModel, +#: params: None, +#: request: Request +#: ) -> Response +#: +#: # Params and Request +#: async def test3( +#: args: None, +#: params: Pydantic.BaseModel, +#: request: Request +#: ) -> Response +#: +#: # Args, Params and Request +#: async def test4( +#: args: Pydantic.BaseModel, +#: params: Pydantic.BaseModel, +#: request: Request +#: ) -> Response +EndpointFunction = Union[ + Callable[ + [], + Awaitable[Response], + ], + Callable[ + ["Request"], + Awaitable[Response], + ], + Callable[ + [PydanticModelType, PydanticModelType, "Request"], + Awaitable[Response], + ], + Callable[ + [None, PydanticModelType, "Request"], + Awaitable[Response], + ], + Callable[ + [PydanticModelType, None, "Request"], + Awaitable[Response], + ], + Callable[ + 
[None, None, "Request"], + Awaitable[Response], + ], + Callable[ + [], + Response, + ], + Callable[ + ["Request"], + Response, + ], + Callable[ + [PydanticModelType, PydanticModelType, "Request"], + Response, + ], + Callable[ + [None, PydanticModelType, "Request"], + Response, + ], + Callable[ + [PydanticModelType, None, "Request"], + Response, + ], + Callable[ + [None, None, "Request"], + Response, + ], +] + + +class RequestStorageProvider(Protocol): + def get(self, key: str, default: Any | None = DefaultNoValue) -> Any: + ... + + def set(self, key: str, value: Any) -> None: + ... + + def pop(self, key: str, default: Any | None = DefaultNoValue) -> Any: + ... + + +class RequestSessionStorageProvider(RequestStorageProvider): + def set_session_permanent(self, value: bool) -> None: + ... + + def is_session_permanent(self) -> bool: + return False + + def clear(self) -> None: + ... + + +class RequestStorage(Protocol): + session: RequestSessionStorageProvider + request: RequestStorageProvider + + +class Request(Protocol): + """Protocol to define common request functionality + + This is implemented in `SuperdeskEve` app using Flask to provide the required functionality. + """ + + #: The current Endpoint being processed + endpoint: "Endpoint" + storage: RequestStorage + user: Any | None + + @property + def method(self) -> HTTP_METHOD: + """Returns the current HTTP method for the request""" + ... + + @property + def url(self) -> str: + """Returns the URL of the current request""" + ... + + @property + def path(self) -> str: + """Returns the path of the current request""" + ... + + def get_header(self, key: str) -> Optional[str]: + """Get an HTTP header from the current request""" + ... + + async def get_json(self) -> Union[Any, None]: + """Get the body of the current request in JSON format""" + ... + + async def get_form(self) -> Mapping: + """Get the body of the current request in form format""" + ... 
+ + async def get_data(self) -> Union[bytes, str]: + """Get the body of the current request in raw bytes format""" + ... + + async def abort(self, code: int, *args: Any, **kwargs: Any) -> NoReturn: + ... + + def get_view_args(self, key: str) -> str | None: + ... + + def get_url_arg(self, key: str) -> str | None: + ... + + def redirect(self, location: str, code: int = 302) -> Any: + ... + + +AuthRule = Callable[[Request], Awaitable[Any | None]] +AuthConfig = Literal[False] | list[AuthRule] | dict[str, AuthRule] | None + + +class Endpoint: + """Base class used for registering and processing endpoints""" + + #: URL for the endpoint + url: str + + #: Name of the endpoint (must be unique) + name: str + + #: HTTP Methods allowed for this endpoint + methods: List[HTTP_METHOD] + + #: The callback function used to process the request + func: EndpointFunction + + auth: AuthConfig + + def __init__( + self, + url: str, + func: EndpointFunction, + methods: list[HTTP_METHOD] | None = None, + name: str | None = None, + auth: AuthConfig = None, + parent: Union["EndpointGroup", None] = None, + ): + self.url = url + self.func = func + self.methods = methods or ["GET"] + self.name = name or func.__name__ + self.auth = auth + self.parent = parent + + async def __call__(self, args: dict[str, Any], params: dict[str, Any], request: Request): + ... + + def get_auth_rules(self) -> AuthConfig: + if self.auth is None: + return self.parent.get_auth_rules() if self.parent is not None else None + return self.auth + + +class EndpointGroup(Endpoint): + """Base class used for registering a group of endpoints""" + + # #: Name for this endpoint group. Will be prepended to each endpoint name + # name: str + + #: The import name of the module where this object is defined. + #: Usually :attr:`__name__` should be used. 
+ import_name: str + + #: Optional url prefix to be added to all routes of this group + url_prefix: Optional[str] + + #: List of endpoints registered with this group + endpoints: List[Endpoint] + + def endpoint( + self, + url: str, + name: str | None = None, + methods: list[HTTP_METHOD] | None = None, + auth: AuthConfig = None, + ): + ... + + +class NotificationClientProtocol(Protocol): + open: bool + messages: Sequence[str] + + def close(self) -> None: + ... + + def send(self, message: str) -> None: + ... + + def reset(self) -> None: + ... + + +class WSGIApp(Protocol): + """Protocol for defining functionality from a WSGI application (such as Eve/Flask) + + A class instance that adheres to this protocol is passed into the SuperdeskAsyncApp constructor. + This way the SuperdeskAsyncApp does not need to know the underlying WSGI application, just that + it provides certain functionality. + """ + + #: Config for the application + config: dict[str, Any] + + #: Config for the front-end application + client_config: dict[str, Any] + + testing: bool = False + + #: Interface to upload/download/query media + media: Any + + mail: Any + + data: Any + + storage: Any + + auth: Any + + subjects: Any + + notification_client: NotificationClientProtocol + + locators: Any + + celery: Any + + redis: Any + + jinja_loader: Any + + jinja_env: Any + + extensions: Dict[str, Any] + + def register_endpoint(self, endpoint: Endpoint | EndpointGroup): + ... + + def register_resource(self, name: str, settings: dict[str, Any]): + ... + + def upload_url(self, media_id: str) -> str: + ... + + def download_url(self, media_id: str) -> str: + ... + + # TODO: Provide proper type here, context manager + def app_context(self): + ... + + def get_current_user_dict(self) -> dict[str, Any] | None: + ... + + def response_class(self, *args, **kwargs) -> Any: + ... + + def validator(self, *args, **kwargs) -> Any: + ... + + def init_indexes(self, ignore_duplicate_keys: bool = False) -> None: + ... 
+ + def as_any(self) -> Any: + ... + + def get_current_request(self) -> Request | None: + ... + + def get_endpoint_for_current_request(self) -> Endpoint | None: + ... + + +class RestResponseMeta(TypedDict): + """Dictionary to hold the response metadata for a REST request""" + + #: Current page requested + page: int + + #: Maximum results requested + max_results: int + + #: Total number of documents found + total: int + + +class RestGetResponse(TypedDict, total=False): + """Dictionary to hold the response for a REST request""" + + #: The list of documents found for the search request + _items: list[dict[str, Any]] + + #: HATEOAS links + _links: dict[str, Any] + + #: Response metadata + _meta: RestResponseMeta diff --git a/superdesk/core/web/__init__.py b/superdesk/core/web/__init__.py index c474ec7a27..ddcb9aff80 100644 --- a/superdesk/core/web/__init__.py +++ b/superdesk/core/web/__init__.py @@ -8,15 +8,14 @@ # AUTHORS and LICENSE files distributed with this source code, or # at https://www.sourcefabric.org/superdesk/license -from .types import ( - HTTP_METHOD, - Response, - EndpointFunction, - Endpoint, - Request, - EndpointGroup, - RestResponseMeta, - RestGetResponse, - endpoint, - WSGIApp, -) +from .endpoints import Endpoint, EndpointGroup, NullEndpoint, endpoint +from .rest_endpoints import RestEndpoints, ItemRequestViewArgs + +__all__ = [ + "Endpoint", + "EndpointGroup", + "NullEndpoint", + "endpoint", + "RestEndpoints", + "ItemRequestViewArgs", +] diff --git a/superdesk/core/web/endpoints.py b/superdesk/core/web/endpoints.py new file mode 100644 index 0000000000..c508f0a18c --- /dev/null +++ b/superdesk/core/web/endpoints.py @@ -0,0 +1,183 @@ +# -*- coding: utf-8; -*- +# +# This file is part of Superdesk. +# +# Copyright 2024 Sourcefabric z.u. and contributors. 
+# +# For the full copyright and license information, please see the +# AUTHORS and LICENSE files distributed with this source code, or +# at https://www.sourcefabric.org/superdesk/license + +from typing import Any, Awaitable +from inspect import signature, isawaitable +import logging + +from pydantic import BaseModel, ValidationError + +from superdesk.core.types import ( + Request, + Response, + EndpointFunction, + HTTP_METHOD, + Endpoint as EndpointProtocol, + AuthConfig, + EndpointGroup as EndpointGroupProtocol, +) + + +logger = logging.getLogger(__name__) + + +class Endpoint(EndpointProtocol): + """Base class used for registering and processing endpoints""" + + async def __call__(self, args: dict[str, Any], params: dict[str, Any], request: Request): + from superdesk.core.resources import ResourceModel + from superdesk.core import get_current_async_app + + # Implement Auth here + if request.endpoint.get_auth_rules() is not False: + async_app = get_current_async_app() + response = await async_app.auth.authenticate(request) + if response is not None: + logger.warning("Authenticate returned a non-None value") + return response + response = await async_app.auth.authorize(request) + if response is not None: + logger.warning("Authorize returned a non-None value") + return response + + response = self._run_endpoint_func(args, params, request) + if isawaitable(response): + response = await response + + if not isinstance(response, Response): + # TODO-ASYNC: Implement our own wrapper around Response for specific use cases + # like redirect or abort and handle these here + # We may have received a different response, such as a flask redirect call + # So we return it here + return response + elif isinstance(response.body, ResourceModel): + response.body = response.body.to_dict() + + return response + + def _run_endpoint_func( + self, args: dict[str, Any], params: dict[str, Any], request: Request + ) -> Awaitable[Response] | Response: + func_params = 
signature(self.func).parameters + if not len(func_params): + return self.func() # type: ignore[call-arg,arg-type] + elif "args" not in func_params and "params" not in func_params: + return self.func(request) # type: ignore[call-arg,arg-type] + + arg_type = func_params["args"] if "args" in func_params else None + request_args = None + if arg_type is not None and arg_type.annotation is not None and issubclass(arg_type.annotation, BaseModel): + request_args = arg_type.annotation.model_validate(args) + + param_type = func_params["params"] if "params" in func_params else None + url_params = None + if ( + param_type is not None + and param_type.annotation is not None + and issubclass(param_type.annotation, BaseModel) + ): + try: + url_params = param_type.annotation.model_validate(params) + except ValidationError as error: + from superdesk.core.resources.validators import get_field_errors_from_pydantic_validation_error + + errors = { + field: list(err.values())[0] + for field, err in get_field_errors_from_pydantic_validation_error(error).items() + } + return Response(errors, 400, ()) + + return self.func(request_args, url_params, request) # type: ignore[call-arg,arg-type] + + +def return_404() -> Response: + return Response("", 404) + + +class NullEndpointClass(Endpoint): + def __init__(self): + super().__init__(url="", func=return_404) + + +NullEndpoint = NullEndpointClass() + + +class EndpointGroup(EndpointGroupProtocol): + """Base class used for registering a group of endpoints""" + + def __init__( + self, + name: str, + import_name: str, + url_prefix: str | None = None, + auth: AuthConfig = None, + ): + super().__init__( + url="", + func=return_404, + auth=auth, + ) + self.name = name + self.import_name = import_name + self.url_prefix = url_prefix + self.endpoints = [] + + def endpoint( + self, + url: str, + name: str | None = None, + methods: list[HTTP_METHOD] | None = None, + auth: AuthConfig = None, + ): + """Decorator function to register an endpoint to this group 
+ + :param url: The URL of the endpoint + :param name: The optional name of the endpoint + :param methods: The optional list of HTTP methods allowed + """ + + def fdec(func: EndpointFunction): + endpoint_func = Endpoint( + f"{self.url}/{url}" if self.url else url, + func, + methods=methods, + name=name, + auth=auth, + parent=self, + ) + self.endpoints.append(endpoint_func) + return endpoint_func + + return fdec + + def __call__(self, args: dict[str, Any], params: dict[str, Any], request: Request): + return return_404() + + +def endpoint(url: str, name: str | None = None, methods: list[HTTP_METHOD] | None = None, auth: AuthConfig = None): + """Decorator function to convert a pure function to an Endpoint instance + + This is then later used to register with a Module or the app. + + :param url: The URL of the endpoint + :param name: The optional name of the endpoint + :param methods: The optional list of HTTP methods allowed + """ + + def convert_to_endpoint(func: EndpointFunction): + return Endpoint( + url=url, + name=name, + methods=methods, + func=func, + auth=auth, + ) + + return convert_to_endpoint diff --git a/superdesk/core/web/rest_endpoints.py b/superdesk/core/web/rest_endpoints.py index be18c8e6af..d4edea3499 100644 --- a/superdesk/core/web/rest_endpoints.py +++ b/superdesk/core/web/rest_endpoints.py @@ -8,12 +8,12 @@ # AUTHORS and LICENSE files distributed with this source code, or # at https://www.sourcefabric.org/superdesk/license -from typing import List, Optional, Any +from typing import Any from pydantic import BaseModel -from superdesk.core.types import SearchRequest -from .types import Endpoint, EndpointGroup, HTTP_METHOD, Request, Response +from superdesk.core.types import SearchRequest, AuthConfig, HTTP_METHOD, Request, Response +from .endpoints import Endpoint, EndpointGroup class ItemRequestViewArgs(BaseModel): @@ -27,24 +27,27 @@ class RestEndpoints(EndpointGroup): url: str #: The list of HTTP methods for the resource endpoints - 
resource_methods: List[HTTP_METHOD] + resource_methods: list[HTTP_METHOD] #: The list of HTTP methods for the resource item endpoints - item_methods: List[HTTP_METHOD] + item_methods: list[HTTP_METHOD] #: Optionally set the route param type for the ID, defaults to ``string`` id_param_type: str + auth: AuthConfig = None + def __init__( self, url: str, name: str, - import_name: Optional[str] = None, - resource_methods: Optional[List[HTTP_METHOD]] = None, - item_methods: Optional[List[HTTP_METHOD]] = None, - id_param_type: Optional[str] = None, + import_name: str | None = None, + resource_methods: list[HTTP_METHOD] | None = None, + item_methods: list[HTTP_METHOD] | None = None, + id_param_type: str | None = None, + auth: AuthConfig = None, ): - super().__init__(name, import_name or __name__) + super().__init__(name, import_name or __name__, auth=auth) self.url = url self.resource_methods = resource_methods or ["GET", "POST"] self.item_methods = item_methods or ["GET", "PATCH", "DELETE"] @@ -58,6 +61,7 @@ def __init__( name="resource_get", func=self.search_items, methods=["GET"], + parent=self, ) ) @@ -68,6 +72,7 @@ def __init__( name="resource_post", func=self.create_item, methods=["POST"], + parent=self, ) ) @@ -79,6 +84,7 @@ def __init__( name="item_get", func=self.get_item, methods=["GET"], + parent=self, ) ) @@ -89,6 +95,7 @@ def __init__( name="item_patch", func=self.update_item, methods=["PATCH"], + parent=self, ) ) @@ -99,6 +106,7 @@ def __init__( name="item_delete", func=self.delete_item, methods=["DELETE"], + parent=self, ) ) diff --git a/superdesk/core/web/types.py b/superdesk/core/web/types.py deleted file mode 100644 index a15712f379..0000000000 --- a/superdesk/core/web/types.py +++ /dev/null @@ -1,399 +0,0 @@ -# -*- coding: utf-8; -*- -# -# This file is part of Superdesk. -# -# Copyright 2024 Sourcefabric z.u. and contributors. 
-# -# For the full copyright and license information, please see the -# AUTHORS and LICENSE files distributed with this source code, or -# at https://www.sourcefabric.org/superdesk/license - -from typing import ( - Any, - Protocol, - Sequence, - Optional, - Callable, - Awaitable, - Literal, - List, - TypeVar, - Mapping, - Union, - TypedDict, - Dict, - NoReturn, -) -from inspect import signature - -from dataclasses import dataclass -from pydantic import BaseModel, ValidationError - -HTTP_METHOD = Literal["GET", "POST", "PATCH", "PUT", "DELETE", "HEAD", "OPTIONS"] - - -PydanticModelType = TypeVar("PydanticModelType", bound=BaseModel) - - -@dataclass -class Response: - """Dataclass for endpoints to return response from a request""" - - #: The body of the response (Flask will determine data type for us) - body: Any - - #: HTTP Status Code of the response - status_code: int = 200 - - #: Any additional headers to be added - headers: Sequence = () - - -#: Function for use with a Endpoint registration and request processing -#: -#: Supported endpoint signatures:: -#: -#: # Response only -#: async def test() -> Response: -#: -#: # Request Only -#: async def test1(request: Request) -> Response -#: -#: # Args and Request -#: async def test2( -#: args: Pydantic.BaseModel, -#: params: None, -#: request: Request -#: ) -> Response -#: -#: # Params and Request -#: async def test3( -#: args: None, -#: params: Pydantic.BaseModel, -#: request: Request -#: ) -> Response -#: -#: # Args, Params and Request -#: async def test4( -#: args: Pydantic.BaseModel, -#: params: Pydantic.BaseModel, -#: request: Request -#: ) -> Response -EndpointFunction = Union[ - Callable[ - [], - Awaitable[Response], - ], - Callable[ - ["Request"], - Awaitable[Response], - ], - Callable[ - [PydanticModelType, PydanticModelType, "Request"], - Awaitable[Response], - ], - Callable[ - [None, PydanticModelType, "Request"], - Awaitable[Response], - ], - Callable[ - [PydanticModelType, None, "Request"], - 
Awaitable[Response], - ], - Callable[ - [None, None, "Request"], - Awaitable[Response], - ], -] - - -class Endpoint: - """Base class used for registering and processing endpoints""" - - #: URL for the endpoint - url: str - - #: Name of the endpoint (must be unique) - name: str - - #: HTTP Methods allowed for this endpoint - methods: List[HTTP_METHOD] - - #: The callback function used to process the request - func: EndpointFunction - - def __init__( - self, - url: str, - func: EndpointFunction, - methods: Optional[List[HTTP_METHOD]] = None, - name: Optional[str] = None, - ): - self.url = url - self.func = func - self.methods = methods or ["GET"] - self.name = name or func.__name__ - - async def __call__(self, args: Dict[str, Any], params: Dict[str, Any], request: "Request"): - func_params = signature(self.func).parameters - if not len(func_params): - return await self.func() # type: ignore[call-arg,arg-type] - elif "args" not in func_params and "params" not in func_params: - return await self.func(request) # type: ignore[call-arg,arg-type] - - arg_type = func_params["args"] if "args" in func_params else None - request_args = None - if arg_type is not None and arg_type.annotation is not None and issubclass(arg_type.annotation, BaseModel): - request_args = arg_type.annotation.model_validate(args) - - param_type = func_params["params"] if "params" in func_params else None - url_params = None - if ( - param_type is not None - and param_type.annotation is not None - and issubclass(param_type.annotation, BaseModel) - ): - try: - url_params = param_type.annotation.model_validate(params) - except ValidationError as error: - from superdesk.core.resources.validators import get_field_errors_from_pydantic_validation_error - - errors = { - field: list(err.values())[0] - for field, err in get_field_errors_from_pydantic_validation_error(error).items() - } - return Response(errors, 400, ()) - - return await self.func(request_args, url_params, request) # type: 
ignore[call-arg,arg-type] - - -class Request(Protocol): - """Protocol to define common request functionality - - This is implemented in `SuperdeskEve` app using Flask to provide the required functionality. - """ - - #: The current Endpoint being processed - endpoint: Endpoint - - @property - def method(self) -> HTTP_METHOD: - """Returns the current HTTP method for the request""" - ... - - @property - def path(self) -> str: - """Returns the URL of the current request""" - ... - - def get_header(self, key: str) -> Optional[str]: - """Get an HTTP header from the current request""" - ... - - async def get_json(self) -> Union[Any, None]: - """Get the body of the current request in JSON format""" - ... - - async def get_form(self) -> Mapping: - """Get the body of the current request in form format""" - ... - - async def get_data(self) -> Union[bytes, str]: - """Get the body of the current request in raw bytes format""" - ... - - async def abort(self, code: int, *args: Any, **kwargs: Any) -> NoReturn: - ... - - def get_view_args(self, key: str) -> str | None: - ... - - def get_url_arg(self, key: str) -> str | None: - ... - - -class EndpointGroup: - """Base class used for registering a group of endpoints""" - - #: Name for this endpoint group. Will be prepended to each endpoint name - name: str - - #: The import name of the module where this object is defined. - #: Usually :attr:`__name__` should be used. 
- import_name: str - - #: Optional url prefix to be added to all routes of this group - url_prefix: Optional[str] - - #: List of endpoints registered with this group - endpoints: List[Endpoint] - - def __init__(self, name: str, import_name: str, url_prefix: Optional[str] = None): - self.name = name - self.import_name = import_name - self.url_prefix = url_prefix - self.endpoints = [] - - def endpoint( - self, - url: str, - name: Optional[str] = None, - methods: Optional[List[HTTP_METHOD]] = None, - ): - """Decorator function to register an endpoint to this group - - :param url: The URL of the endpoint - :param name: The optional name of the endpoint - :param methods: The optional list of HTTP methods allowed - """ - - def fdec(func: EndpointFunction): - endpoint_func = Endpoint( - f"{self.url_prefix}/{url}" if self.url_prefix else url, - func, - methods=methods, - name=name, - ) - self.endpoints.append(endpoint_func) - return endpoint_func - - return fdec - - -class RestResponseMeta(TypedDict): - """Dictionary to hold the response metadata for a REST request""" - - #: Current page requested - page: int - - #: Maximum results requested - max_results: int - - #: Total number of documents found - total: int - - -class RestGetResponse(TypedDict, total=False): - """Dictionary to hold the response for a REST request""" - - #: The list of documents found for the search request - _items: List[Dict[str, Any]] - - #: HATEOAS links - _links: Dict[str, Any] - - #: Response metadata - _meta: RestResponseMeta - - -def endpoint(url: str, name: Optional[str] = None, methods: Optional[List[HTTP_METHOD]] = None): - """Decorator function to convert a pure function to an Endpoint instance - - This is then later used to register with a Module or the app. 
- - :param url: The URL of the endpoint - :param name: The optional name of the endpoint - :param methods: The optional list of HTTP methods allowed - """ - - def convert_to_endpoint(func: EndpointFunction): - return Endpoint( - url=url, - name=name, - methods=methods, - func=func, - ) - - return convert_to_endpoint - - -class NotificationClientProtocol(Protocol): - open: bool - messages: Sequence[str] - - def close(self) -> None: - ... - - def send(self, message: str) -> None: - ... - - def reset(self) -> None: - ... - - -class WSGIApp(Protocol): - """Protocol for defining functionality from a WSGI application (such as Eve/Flask) - - A class instance that adheres to this protocol is passed into the SuperdeskAsyncApp constructor. - This way the SuperdeskAsyncApp does not need to know the underlying WSGI application, just that - it provides certain functionality. - """ - - #: Config for the application - config: Dict[str, Any] - - #: Config for the front-end application - client_config: Dict[str, Any] - - testing: Optional[bool] - - #: Interface to upload/download/query media - media: Any - - mail: Any - - data: Any - - storage: Any - - auth: Any - - subjects: Any - - notification_client: NotificationClientProtocol - - locators: Any - - celery: Any - - redis: Any - - jinja_loader: Any - - jinja_env: Any - - extensions: Dict[str, Any] - - def register_endpoint(self, endpoint: Endpoint | EndpointGroup): - ... - - def register_resource(self, name: str, settings: Dict[str, Any]): - ... - - def upload_url(self, media_id: str) -> str: - ... - - def download_url(self, media_id: str) -> str: - ... - - # TODO: Provide proper type here, context manager - def app_context(self): - ... - - def get_current_user_dict(self) -> Optional[Dict[str, Any]]: - ... - - def response_class(self, *args, **kwargs) -> Any: - ... - - def validator(self, *args, **kwargs) -> Any: - ... - - def init_indexes(self, ignore_duplicate_keys: bool = False) -> None: - ... 
- - def as_any(self) -> Any: - ... - - # TODO: Change how we use events on the app - # def on_role_privileges_updated(self, role: Any, role_users: Any) -> None: ... diff --git a/superdesk/default_settings.py b/superdesk/default_settings.py index 43cd90ff38..e6f87dda6f 100644 --- a/superdesk/default_settings.py +++ b/superdesk/default_settings.py @@ -420,6 +420,8 @@ def local_to_utc_hour(hour): #: MODULES = [] +ASYNC_AUTH_CLASS = "superdesk.core.auth.token_auth:TokenAuthorization" + #: LDAP Server (eg: ldap://sourcefabric.org) LDAP_SERVER = env("LDAP_SERVER", "") #: LDAP Server port diff --git a/superdesk/factory/app.py b/superdesk/factory/app.py index f864f67ce9..ddef7a00c9 100644 --- a/superdesk/factory/app.py +++ b/superdesk/factory/app.py @@ -9,7 +9,7 @@ # AUTHORS and LICENSE files distributed with this source code, or # at https://www.sourcefabric.org/superdesk/license -from typing import Dict, Any, Type, List, Optional, Union, Mapping, cast, NoReturn +from typing import Dict, Any, Type, Optional, Union, Mapping, cast, NoReturn import os import eve @@ -29,7 +29,16 @@ from pymongo.errors import DuplicateKeyError from superdesk.commands import configure_cli -from superdesk.flask import g, url_for, Config, Request as FlaskRequest, abort, Blueprint, request as flask_request +from superdesk.flask import ( + g, + url_for, + Config, + Request as FlaskRequest, + Blueprint, + request as flask_request, + redirect, + session, +) from superdesk.celery_app import init_celery from superdesk.datalayer import SuperdeskDataLayer # noqa from superdesk.errors import SuperdeskError, SuperdeskApiError, DocumentError @@ -41,27 +50,96 @@ from superdesk.cache import cache_backend from .elastic_apm import setup_apm +from superdesk.core.types import ( + DefaultNoValue, + Endpoint, + Request, + RequestStorage, + RequestSessionStorageProvider, + EndpointGroup, + HTTP_METHOD, + Response, +) from superdesk.core.app import SuperdeskAsyncApp from superdesk.core.resources import ResourceModel 
-from superdesk.core.web import Endpoint, Request, EndpointGroup, HTTP_METHOD, Response +from superdesk.core.web import NullEndpoint SUPERDESK_PATH = os.path.abspath(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) logger = logging.getLogger(__name__) +class FlaskStorageProvider: + @property + def _store(self): + raise NotImplementedError() + + def get(self, key: str, default: Any | None = DefaultNoValue) -> Any: + return self._store.get(key, default) if default is not DefaultNoValue else self._store.get(key) + + def set(self, key: str, value: Any) -> None: + setattr(self._store, key, value) + # self._store[key] = value + + def pop(self, key: str, default: Any | None = DefaultNoValue) -> Any: + return self._store.pop(key, default) if default is not DefaultNoValue else self._store.pop(key) + + +class FlaskSessionStorage(RequestSessionStorageProvider): + def get(self, key: str, default: Any | None = DefaultNoValue) -> Any: + return session.get(key, default) if default is not DefaultNoValue else session.get(key) + + def set(self, key: str, value: Any) -> None: + session[key] = value + + def pop(self, key: str, default: Any | None = DefaultNoValue) -> Any: + return session.pop(key, default) if default is not DefaultNoValue else session.pop(key) + + def set_session_permanent(self, value: bool) -> None: + session.permanent = value + + def is_session_permanent(self) -> bool: + return session.permanent + + def clear(self): + session.clear() + + +class FlaskRequestStorage(FlaskStorageProvider): + def get(self, key: str, default: Any | None = DefaultNoValue) -> Any: + return g.get(key, default) if default is not DefaultNoValue else g.get(key) + + def set(self, key: str, value: Any) -> None: + setattr(g, key, value) + + def pop(self, key: str, default: Any | None = DefaultNoValue) -> Any: + return g.pop(key, default) if default is not DefaultNoValue else g.pop(key) + + +class HttpFlaskRequestStorage(RequestStorage): + session = FlaskSessionStorage() + request =
FlaskRequestStorage() + + class HttpFlaskRequest(Request): endpoint: Endpoint request: FlaskRequest + storage = HttpFlaskRequestStorage() + user: Any | None def __init__(self, endpoint: Endpoint, request: FlaskRequest): self.endpoint = endpoint self.request = request + self.user = None @property def method(self) -> HTTP_METHOD: return cast(HTTP_METHOD, self.request.method) + @property + def url(self) -> str: + return self.request.url + @property def path(self) -> str: return self.request.path @@ -79,6 +157,8 @@ async def get_data(self) -> Union[bytes, str]: return await self.request.get_data() async def abort(self, code: int, *args: Any, **kwargs: Any) -> NoReturn: + from quart import abort + abort(code, *args, **kwargs) def get_view_args(self, key: str) -> str | None: @@ -87,6 +167,9 @@ def get_view_args(self, key: str) -> str | None: def get_url_arg(self, key: str) -> str | None: return self.request.args.get(key, None) + def redirect(self, location: str, code: int = 302) -> Any: + return redirect(location, code) + def set_error_handlers(app): """Set error handlers for the given application object. 
@@ -126,19 +209,22 @@ def server_error_handler(error): class SuperdeskEve(eve.Eve): async_app: SuperdeskAsyncApp - _endpoints: List[Endpoint] - _endpoint_groups: List[EndpointGroup] + _endpoints: list[Endpoint] + _endpoint_groups: list[EndpointGroup] + _endpoint_lookup: dict[str, Endpoint | EndpointGroup] media: Any data: Any def __init__(self, **kwargs): - self.async_app = SuperdeskAsyncApp(self) self.json_provider_class = SuperdeskFlaskJSONProvider self._endpoints = [] self._endpoint_groups = [] - + self._endpoint_lookup = {} super().__init__(**kwargs) + self.async_app = SuperdeskAsyncApp(self) + + self.teardown_request(self._after_each_request) def __getattr__(self, name): """Only use events for on_* methods.""" @@ -226,10 +312,16 @@ def register_endpoint(self, endpoint: Endpoint | EndpointGroup): ) self._endpoints.append(endpoint) - async def _process_async_endpoint(self, **kwargs): - # Get Endpoint instance + def get_endpoint_for_current_request(self) -> Endpoint | None: + if not flask_request or flask_request.endpoint is None: + return None - endpoint_name = flask_request.endpoint + lookup_name = endpoint_name = flask_request.endpoint + + try: + return self._endpoint_lookup[lookup_name] + except KeyError: + pass # Using the requests Blueprint, determine if this request is for an EndpointGroup blueprint_name = flask_request.blueprint @@ -249,25 +341,34 @@ async def _process_async_endpoint(self, **kwargs): if endpoint is None: endpoint = next((e for e in self._endpoints if e.name == endpoint_name), None) + if endpoint is not None: + self._endpoint_lookup[lookup_name] = endpoint + + return endpoint + + def _after_each_request(self, *args, **kwargs): + g._request_instance = None + g.user_instance = None + g.company_instance = None + + async def _process_async_endpoint(self, **kwargs): + request = self.get_current_request() + # We were still unable to find the final Endpoint, return a 404 now - if endpoint is None: + if request is None: raise NotFound() - response 
= await endpoint( + response = await request.endpoint( kwargs, dict(flask_request.args.deepcopy()), - HttpFlaskRequest(endpoint, flask_request), + request, ) - if not isinstance(response, Response): - # We may have received a different response, such as a flask redirect call - # So we return it here - return response - elif isinstance(response.body, ResourceModel): - response.body = response.body.to_dict() - return response.body, response.status_code, response.headers + return ( + response if not isinstance(response, Response) else (response.body, response.status_code, response.headers) + ) - def get_current_user_dict(self) -> Optional[Dict[str, Any]]: + def get_current_user_dict(self) -> dict[str, Any] | None: return getattr(g, "user", None) def download_url(self, media_id: str) -> str: @@ -277,6 +378,24 @@ def download_url(self, media_id: str) -> str: def as_any(self) -> Any: return self + def get_current_request(self, req=None) -> HttpFlaskRequest | None: + try: + if not flask_request and not req: + return None + except AttributeError: + return None + + existing_instance = g.get("_request_instance", None) + if existing_instance: + return cast(HttpFlaskRequest, existing_instance) + + endpoint = self.get_endpoint_for_current_request() or NullEndpoint + new_request = HttpFlaskRequest(endpoint, req or flask_request) + g._request_instance = new_request # type: ignore[attr-defined] + if not new_request.user: + new_request.user = self.async_app.auth.get_current_user(new_request) + return new_request + def get_media_storage_class(app_config: Dict[str, Any], use_provider_config: bool = True) -> Type[MediaStorage]: if use_provider_config and app_config.get("MEDIA_STORAGE_PROVIDER"): diff --git a/tests/core/modules/company.py b/tests/core/modules/company.py index 211f86fbae..88dd072681 100644 --- a/tests/core/modules/company.py +++ b/tests/core/modules/company.py @@ -14,7 +14,7 @@ class CompanyService(AsyncResourceService[CompanyResource]): name="companies", 
data_class=CompanyResource, service=CompanyService, - rest_endpoints=RestEndpointConfig(), + rest_endpoints=RestEndpointConfig(auth=False), ) module = Module(name="tests.company", resources=[companies_resource_config]) diff --git a/tests/core/modules/content/resources.py b/tests/core/modules/content/resources.py index 587ffbf52d..b1d19a5f43 100644 --- a/tests/core/modules/content/resources.py +++ b/tests/core/modules/content/resources.py @@ -19,7 +19,7 @@ class ContentResourceService(AsyncResourceService[Content]): service=ContentResourceService, versioning=True, ignore_fields_in_versions=["lock_user"], - rest_endpoints=RestEndpointConfig(), + rest_endpoints=RestEndpointConfig(auth=False), mongo=MongoResourceConfig( indexes=[ MongoIndexOptions( diff --git a/tests/core/modules/topics.py b/tests/core/modules/topics.py index 04722323ae..c50f4bbf04 100644 --- a/tests/core/modules/topics.py +++ b/tests/core/modules/topics.py @@ -32,6 +32,7 @@ class UserFolderService(AsyncResourceService[UserFolder]): data_class=UserFolder, service=UserFolderService, rest_endpoints=RestEndpointConfig( + auth=False, parent_links=[ RestParentLink( resource_name=user_model_config.name, @@ -57,6 +58,7 @@ class CompanyFolderService(AsyncResourceService[CompanyFolder]): data_class=CompanyFolder, service=CompanyFolderService, rest_endpoints=RestEndpointConfig( + auth=False, parent_links=[ RestParentLink( resource_name=companies_resource_config.name, diff --git a/tests/core/modules/users/endpoints.py b/tests/core/modules/users/endpoints.py index 715ab93b2f..56a26d9642 100644 --- a/tests/core/modules/users/endpoints.py +++ b/tests/core/modules/users/endpoints.py @@ -3,7 +3,8 @@ from pydantic import BaseModel from superdesk.core.app import get_current_async_app -from superdesk.core.web import Request, Response, EndpointGroup, endpoint +from superdesk.core.types import Request, Response +from superdesk.core.web import EndpointGroup, endpoint from superdesk.errors import SuperdeskApiError @@ -15,7 
+16,7 @@ class RequestParams(BaseModel): resource: Optional[str] = None -endpoints = EndpointGroup("users", __name__) +endpoints = EndpointGroup("users", __name__, auth=False) @endpoints.endpoint( @@ -46,6 +47,6 @@ async def get_user_ids(request: Request) -> Response: return Response({"ids": item_ids}, 200, ()) -@endpoint("hello/world", methods=["GET"]) +@endpoint("hello/world", methods=["GET"], auth=False) async def hello_world(request: Request) -> Response: return Response({"hello": "world"}, 200, ()) diff --git a/tests/core/modules/users/resources.py b/tests/core/modules/users/resources.py index 415aa7baef..e3bca6f7e6 100644 --- a/tests/core/modules/users/resources.py +++ b/tests/core/modules/users/resources.py @@ -35,5 +35,5 @@ class UserResourceService(AsyncResourceService[User]): ), elastic=ElasticResourceConfig(), service=UserResourceService, - rest_endpoints=RestEndpointConfig(), + rest_endpoints=RestEndpointConfig(auth=False), ) diff --git a/tests/core/resource_endpoints_test.py b/tests/core/resource_endpoints_test.py index 930f8108a3..131a2165d1 100644 --- a/tests/core/resource_endpoints_test.py +++ b/tests/core/resource_endpoints_test.py @@ -31,7 +31,7 @@ async def test_post(self, mock_utcnow): test_user.created = NOW test_user.updated = NOW test_user_dict = test_user.to_dict(context={"use_objectid": True}) - mongo_item = await self.service.mongo.find_one({"_id": test_user.id}) + mongo_item = await self.service.mongo_async.find_one({"_id": test_user.id}) test_user.etag = test_user_dict["_etag"] = mongo_item["_etag"] self.assertEqual(mongo_item, test_user_dict) @@ -181,6 +181,7 @@ async def test_search(self, mock_utcnow): async def test_hateoas(self): # Test hateoas with empty resource response = await self.test_client.get("/api/users_async") + self.assertEqual(response.status_code, 200) json_data = await response.get_json() self.assertEqual( json_data["_meta"], diff --git a/tests/core/resource_service_test.py b/tests/core/resource_service_test.py index 
2b7f3f4406..609896e326 100644 --- a/tests/core/resource_service_test.py +++ b/tests/core/resource_service_test.py @@ -47,7 +47,7 @@ async def test_create(self, mock_utcnow): test_user.created = NOW test_user.updated = NOW test_user_dict = test_user.to_dict(context={"use_objectid": True}) - mongo_item = await self.service.mongo.find_one({"_id": test_user.id}) + mongo_item = await self.service.mongo_async.find_one({"_id": test_user.id}) self.assertEqual(mongo_item, test_user_dict) # Check stored `_etag` vs generated one @@ -142,7 +142,7 @@ async def test_update(self, mock_utcnow): self.assertEqual(test_user, item) # Test the User was updated in MongoDB with correct data - mongo_item = await self.service.mongo.find_one({"_id": test_user.id}) + mongo_item = await self.service.mongo_async.find_one({"_id": test_user.id}) self.assertEqual(mongo_item, test_user_dict) # Make sure ObjectIds are not stored as strings @@ -181,7 +181,7 @@ async def test_delete(self): # Now delete the user, and make sure it is gone from both MongoDB, Elastic # and the resource service await self.service.delete(test_user) - self.assertIsNone(await self.service.mongo.find_one({"_id": test_user.id})) + self.assertIsNone(await self.service.mongo_async.find_one({"_id": test_user.id})) self.assertIsNone(await self.service.elastic.find_by_id(test_user.id)) self.assertIsNone(await self.service.find_one(_id=test_user.id)) diff --git a/tests/core/resource_version_test.py b/tests/core/resource_version_test.py index 4078c48a1e..2deccdc19f 100644 --- a/tests/core/resource_version_test.py +++ b/tests/core/resource_version_test.py @@ -40,15 +40,15 @@ async def test_mongo_client(self): async def test_create_versions(self, *args): await self.service.create([Content(id="content_1", guid="content_1", headline="Some article")]) - self.assertEqual(await self.service.mongo_versioned.count_documents({"_id_document": "content_1"}), 1) - cursor = self.service.mongo_versioned.find({"_id_document": "content_1"}) + 
self.assertEqual(await self.service.mongo_versioned_async.count_documents({"_id_document": "content_1"}), 1) + cursor = self.service.mongo_versioned_async.find({"_id_document": "content_1"}) self.assertIsNotNone(await cursor.next()) with self.assertRaises(StopAsyncIteration): await cursor.next() await self.service.update("content_1", dict(headline="Some article updated")) - self.assertEqual(await self.service.mongo_versioned.count_documents({"_id_document": "content_1"}), 2) - cursor = self.service.mongo_versioned.find({"_id_document": "content_1"}) + self.assertEqual(await self.service.mongo_versioned_async.count_documents({"_id_document": "content_1"}), 2) + cursor = self.service.mongo_versioned_async.find({"_id_document": "content_1"}) self.assertIsNotNone(await cursor.next()) self.assertIsNotNone(await cursor.next())