diff --git a/app/client/schemas.py b/app/client/schemas.py index a87af755..cae4c497 100644 --- a/app/client/schemas.py +++ b/app/client/schemas.py @@ -1,6 +1,7 @@ -from pydantic import Field, HttpUrl +from pydantic import Field, HttpUrl, field_validator from app.schemas import CustomModel, ClientResponse, PaginationResponse +from app import constants, utils class ClientFullResponse(ClientResponse): @@ -15,31 +16,77 @@ class ClientPaginationResponse(CustomModel): class ClientCreate(CustomModel): name: str = Field( - examples=["ThirdPartyWatchlistImporter"], description="Client name" + examples=["ThirdPartyWatchlistImporter"], + description="Client name", + min_length=3, + max_length=constants.MAX_CLIENT_NAME_LENGTH, ) description: str = Field( examples=["Client that imports watchlist from third-party services"], description="Short clear description of the client", + min_length=3, + max_length=constants.MAX_CLIENT_DESCRIPTION_LENGTH, ) endpoint: HttpUrl = Field( examples=["https://example.com", "http://localhost/auth/confirm"], description="Endpoint of the client. " "User will be redirected to that endpoint after successful " "authorization", + max_length=constants.MAX_CLIENT_ENDPOINT_LENGTH, ) + @field_validator("name", "description", mode="before") + def validate_name(cls, v: str) -> str: + if not isinstance(v, str): + return v + + return utils.remove_bad_characters(v).strip() + + @field_validator("endpoint") + def validate_endpoint(cls, v: HttpUrl) -> HttpUrl: + if len(str(v)) > constants.MAX_CLIENT_ENDPOINT_LENGTH: + raise ValueError( + f"Endpoint length should be less than {constants.MAX_CLIENT_ENDPOINT_LENGTH}" + ) + + return v + class ClientUpdate(CustomModel): name: str | None = Field( None, description="Client name", + max_length=constants.MAX_CLIENT_NAME_LENGTH, + min_length=3, ) description: str | None = Field( None, description="Short clear description of the client", + max_length=constants.MAX_CLIENT_DESCRIPTION_LENGTH, + min_length=3, + ) + endpoint: HttpUrl | None = Field( + None, + description="Endpoint of the client", + max_length=constants.MAX_CLIENT_ENDPOINT_LENGTH, ) - endpoint: HttpUrl | None = Field(None, description="Endpoint of the client") revoke_secret: bool = Field( False, description="Create new client secret and revoke previous", ) + + @field_validator("name", "description", mode="before") + def validate_name(cls, v: str) -> str: + if not isinstance(v, str): + return v + + return utils.remove_bad_characters(v).strip() + + @field_validator("endpoint") + def validate_endpoint(cls, v: HttpUrl | None) -> HttpUrl | None: + if len(str(v)) > constants.MAX_CLIENT_ENDPOINT_LENGTH: + raise ValueError( + f"Endpoint length should be less than {constants.MAX_CLIENT_ENDPOINT_LENGTH}" + ) + + return v diff --git a/app/constants.py b/app/constants.py index aff52d3a..708507f4 100644 --- a/app/constants.py +++ b/app/constants.py @@ -101,6 +101,10 @@ MAX_USER_CLIENTS = 10 +MAX_CLIENT_NAME_LENGTH = 128 +MAX_CLIENT_DESCRIPTION_LENGTH = 512 +MAX_CLIENT_ENDPOINT_LENGTH = 128 + # Meilisearch index names SEARCH_INDEX_CHARACTERS = "content_characters" SEARCH_INDEX_COMPANIES = "content_companies" diff --git a/tests/client/test_client_create.py b/tests/client/test_client_create.py index 4d3d0f8b..94c0dd3a 100644 --- a/tests/client/test_client_create.py +++ b/tests/client/test_client_create.py @@ -1,6 +1,7 @@ from starlette import status from tests.client_requests import request_client_create +from app import constants async def test_client_create(client, test_token, test_user): @@ -51,3 +52,83 @@ async def test_client_create_double(client, test_token): ) assert response.status_code == status.HTTP_400_BAD_REQUEST assert response.json()["code"] == "client:already_exists" + + +async def test_too_long_fields(client, test_token): + error_message_format = "Invalid field {field} in request body" + error_code = "system:validation_error" + + response = await request_client_create( + client, + test_token, + "a" * (constants.MAX_CLIENT_NAME_LENGTH + 1), + "description", + "http://localhost/", + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.json()["code"] == error_code + assert response.json()["message"] == error_message_format.format( + field="name" + ) + + response = await request_client_create( + client, + test_token, + "name", + "a" * (constants.MAX_CLIENT_DESCRIPTION_LENGTH + 1), + "http://localhost/", + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.json()["code"] == error_code + assert response.json()["message"] == error_message_format.format( + field="description" + ) + + response = await request_client_create( + client, + test_token, + "name", + "description", + "http://localhost/" + "a" * (constants.MAX_CLIENT_ENDPOINT_LENGTH + 1), + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.json()["code"] == error_code + assert response.json()["message"] == error_message_format.format( + field="endpoint" + ) + + +async def test_too_short_fields(client, test_token): + error_message_format = "Invalid field {field} in request body" + error_code = "system:validation_error" + + response = await request_client_create( + client, + test_token, + "a", + "description", + "http://localhost/", + ) + assert response.status_code == status.HTTP_400_BAD_REQUEST + + assert response.json()["code"] == error_code + assert response.json()["message"] == error_message_format.format( + field="name" + ) + + response = await request_client_create( + client, + test_token, + "name", + "a", + "http://localhost/", + ) + assert response.status_code == status.HTTP_400_BAD_REQUEST + + assert response.json()["code"] == error_code + assert response.json()["message"] == error_message_format.format( + field="description" + ) diff --git a/tests/client/test_client_update.py b/tests/client/test_client_update.py index 781dacbd..5c33de1d 100644 --- a/tests/client/test_client_update.py +++ b/tests/client/test_client_update.py @@ -3,6 +3,7 @@ from starlette import status from tests.client_requests import request_client_create, request_client_update +from app import constants async def test_client_update(client, test_token): @@ -52,3 +53,95 @@ async def test_client_update_nonexistent(client, test_token): assert response.status_code == status.HTTP_404_NOT_FOUND assert response.json()["code"] == "client:not_found" + + +async def test_too_long_fields(client, test_token): + error_message_format = "Invalid field {field} in request body" + error_code = "system:validation_error" + + name = "test-client" + description = "test client description" + endpoint = "http://localhost/" + + response = await request_client_create( + client, test_token, name, description, endpoint + ) + assert response.status_code == status.HTTP_200_OK + + client_reference = response.json()["reference"] + + response = await request_client_update( + client, + test_token, + client_reference, + name="a" * (constants.MAX_CLIENT_NAME_LENGTH + 1), + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.json()["code"] == error_code + assert response.json()["message"] == error_message_format.format( + field="name" + ) + + response = await request_client_update( + client, + test_token, + client_reference, + description="a" * (constants.MAX_CLIENT_DESCRIPTION_LENGTH + 1), + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.json()["code"] == error_code + assert response.json()["message"] == error_message_format.format( + field="description" + ) + + response = await request_client_update( + client, + test_token, + client_reference, + endpoint="http://localhost/" + + "a" * (constants.MAX_CLIENT_ENDPOINT_LENGTH + 1), + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.json()["code"] == error_code + assert response.json()["message"] == error_message_format.format( + field="endpoint" + ) + + +async def test_too_short_fields(client, test_token): + error_message_format = "Invalid field {field} in request body" + error_code = "system:validation_error" + + name = "test-client" + description = "test client description" + endpoint = "http://localhost/" + + response = await request_client_create( + client, test_token, name, description, endpoint + ) + assert response.status_code == status.HTTP_200_OK + + client_reference = response.json()["reference"] + + response = await request_client_update( + client, test_token, client_reference, name="a" + ) + assert response.status_code == status.HTTP_400_BAD_REQUEST + + assert response.json()["code"] == error_code + assert response.json()["message"] == error_message_format.format( + field="name" + ) + + response = await request_client_update( + client, test_token, client_reference, description="a" + ) + assert response.status_code == status.HTTP_400_BAD_REQUEST + + assert response.json()["code"] == error_code + assert response.json()["message"] == error_message_format.format( + field="description" + )