Skip to content

Commit

Permalink
server/article: fix body validation
Browse files Browse the repository at this point in the history
  • Loading branch information
frankie567 committed Mar 21, 2024
1 parent 205064a commit 5343131
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 19 deletions.
34 changes: 33 additions & 1 deletion server/polar/article/schemas.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import base64
import datetime
import re
from typing import Literal, Self
from uuid import UUID

from pydantic import Field, HttpUrl
from pydantic import Field, HttpUrl, model_validator

from polar.kit.schemas import Schema
from polar.models.article import Article as ArticleModel
Expand Down Expand Up @@ -203,6 +204,22 @@ class ArticleCreate(Schema):
default=None, description="Custom og:description value"
)

@model_validator(mode="after")
def check_either_body_or_body_base64(self) -> Self:
if self.body is not None and self.body_base64 is not None:
raise ValueError(
"Only one of body or body_base64 can be provided, not both."
)
if self.body is None and self.body_base64 is None:
raise ValueError("Either body or body_base64 must be provided.")
return self

def get_body(self) -> str:
if self.body is not None:
return self.body
assert self.body_base64 is not None
return base64.b64decode(self.body_base64).decode("utf-8")


class ArticleUpdate(Schema):
title: str | None = None
Expand Down Expand Up @@ -260,6 +277,21 @@ class ArticleUpdate(Schema):
default=None, description="Custom og:description value"
)

@model_validator(mode="after")
def check_either_body_or_body_base64(self) -> Self:
if self.body is not None and self.body_base64 is not None:
raise ValueError(
"Only one of body or body_base64 can be provided, not both."
)
return self

def get_body(self) -> str | None:
if self.body is not None:
return self.body
if self.body_base64 is not None:
return base64.b64decode(self.body_base64).decode("utf-8")
return None


class ArticleViewedResponse(Schema):
ok: bool
Expand Down
21 changes: 3 additions & 18 deletions server/polar/article/service.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import base64
from collections.abc import Sequence
from datetime import datetime
from operator import and_, or_
Expand Down Expand Up @@ -75,16 +74,6 @@ async def create(
"This slug has been used more than 100 times in this organization."
)

body: str | None
if create_schema.body is not None and create_schema.body_base64 is not None:
raise BadRequest("body and body_base64 are mutually exclusive")
if create_schema.body is not None:
body = create_schema.body
if create_schema.body_base64 is not None:
body = base64.b64decode(create_schema.body_base64).decode("utf-8")
if body is None:
raise BadRequest("No body provided")

published_at: datetime | None = None
if create_schema.visibility == "public":
published_at = utc_now()
Expand All @@ -94,7 +83,7 @@ async def create(
article = Article(
slug=slug,
title=create_schema.title,
body=body,
body=create_schema.get_body(),
created_by=subject.id,
organization_id=create_schema.organization_id,
byline=create_schema.byline,
Expand Down Expand Up @@ -258,12 +247,8 @@ async def update(
if update.slug is not None:
article.slug = polar_slugify(update.slug)

if update.body is not None and update.body_base64 is not None:
raise BadRequest("body and body_base64 are mutually exclusive")
if update.body is not None:
article.body = update.body
if update.body_base64 is not None:
article.body = base64.b64decode(update.body_base64).decode("utf-8")
if body := update.get_body():
article.body = body

if update.byline is not None:
article.byline = (
Expand Down
25 changes: 25 additions & 0 deletions server/tests/article/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,31 @@
from tests.fixtures.random_objects import create_user


@pytest.mark.asyncio
@pytest.mark.http_auto_expunge
async def test_create_no_body(
user: User,
organization: Organization,
user_organization: UserOrganization, # makes User a member of Organization
auth_jwt: str,
client: AsyncClient,
save_fixture: SaveFixture,
) -> None:
user_organization.is_admin = True
await save_fixture(user_organization)

response = await client.post(
"/api/v1/articles",
json={
"title": "Hello World!",
"organization_id": str(organization.id),
},
cookies={settings.AUTH_COOKIE_KEY: auth_jwt},
)

assert response.status_code == 422


@pytest.mark.asyncio
@pytest.mark.http_auto_expunge
async def test_create(
Expand Down

0 comments on commit 5343131

Please sign in to comment.