Skip to content

Commit

Permalink
feat: add correct tests for annotations and observations
Browse files Browse the repository at this point in the history
  • Loading branch information
Nastaliss committed Apr 1, 2024
1 parent ae0ed9b commit ceec8b5
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 18 deletions.
4 changes: 2 additions & 2 deletions src/app/api/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from pydantic import BaseModel, Field, validator

from app.db.models import AccessType, MediaType
from app.db.models import AccessType, MediaType, ObservationType


# Template classes
Expand Down Expand Up @@ -88,7 +88,7 @@ class MediaUrl(BaseModel):
# Annotation
class AnnotationIn(BaseModel):
media_id: int = Field(..., gt=0)
observations: List[str] = Field(..., min_items=0)
observations: List[ObservationType]


class AnnotationOut(AnnotationIn, _CreatedAt, _Id):
Expand Down
9 changes: 8 additions & 1 deletion src/app/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,19 @@ def __repr__(self):
return f"<Media(bucket_key='{self.bucket_key}', type='{self.type}'>"


class ObservationType(str, enum.Enum):
fire: str = "fire"
smoke: str = "smoke"
clouds: str = "clouds"
sky: str = "sky"
fog: str = "fog"

class Annotations(Base):
__tablename__ = "annotations"

id = Column(Integer, primary_key=True)
media_id = Column(Integer, ForeignKey("media.id"))
observations = Column(ARRAY(String(50)), nullable=False)
observations = Column(ARRAY(Enum(ObservationType)), nullable=False)
created_at = Column(DateTime, default=func.now())

media = relationship("Media", uselist=False, back_populates="annotations")
Expand Down
2 changes: 1 addition & 1 deletion src/tests/crud/test_authorizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@


ANNOTATIONS_TABLE = [
{"id": 1, "media_id": 1, "created_at": "2020-10-13T08:18:45.447773"},
{"id": 1, "media_id": 1, "observations": [], "created_at": "2020-10-13T08:18:45.447773"},
]


Expand Down
2 changes: 1 addition & 1 deletion src/tests/db_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ async def fill_table(test_db: Database, table: Table, entries: List[Dict[str, An
are not incremented if the "id" field is included
"""
if remove_ids:
entries = [{k: v for k, v in x.items() if k != "id"} for x in entries]
entries = [{k: v for k, v in entry.items() if k != "id"} for entry in entries]

query = table.insert().values(entries)
await test_db.execute(query=query)
Expand Down
39 changes: 26 additions & 13 deletions src/tests/routes/test_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from tests.db_utils import TestSessionLocal, fill_table, get_entry
from tests.utils import update_only_datetime

from app.db.models import ObservationType

ACCESS_TABLE = [
{"id": 1, "login": "first_login", "hashed_password": "hashed_pwd", "scope": "user"},
{"id": 2, "login": "second_login", "hashed_password": "hashed_pwd", "scope": "admin"},
Expand All @@ -20,7 +22,7 @@
]

ANNOTATIONS_TABLE = [
{"id": 1, "media_id": 1, "observations": ["fire", "smoke", "dog", "coffee"], "created_at": "2020-10-13T08:18:45.447773"},
{"id": 1, "media_id": 1, "observations": [ObservationType.fire, ObservationType.smoke, ObservationType.clouds], "created_at": "2020-10-13T08:18:45.447773"},
{"id": 2, "media_id": 2, "observations": [], "created_at": "2022-10-13T08:18:45.447773"},
]

Expand Down Expand Up @@ -95,25 +97,31 @@ async def test_fetch_annotations(
@pytest.mark.parametrize(
"access_idx, payload, status_code, status_details",
[
[None, {"media_id": 1}, 401, "Not authenticated"],
[0, {"media_id": 1}, 201, None],
[1, {"media_id": 1}, 201, None],
[1, {"media_id": "alpha"}, 422, None],
[None, {"media_id": 1, "observations": []}, 401, "Not authenticated"],
[0, {"media_id": 1, "observations": []}, 201, None],
[1, {"media_id": 1, "observations": []}, 201, None],
[1, {"media_id": 1, "observations": ["clouds"]}, 201, None],
[1, {"media_id": 1, "observations": ["clouds", "fire", "smoke"]}, 201, None],
[1, {"media_id": 1, "observations": ["clouds", "fire", "puppy"]}, 422, None],
[1, {"media_id": 1, "observations": [1337]}, 422, None],
[1, {"media_id": 1, "observations": "smoke"}, 422, None],
[1, {"media_id": "alpha", "observations": []}, 422, None],
[1, {}, 422, None],
[1, {"media_id": 1}, 422, None],
[1, {"observations": []}, 422, None],
],
)
@pytest.mark.asyncio
async def test_create_annotation(
test_app_asyncio, init_test_db, test_db, access_idx, payload, status_code, status_details
):

# Create a custom access token
auth = None
if isinstance(access_idx, int):
auth = await pytest.get_token(ACCESS_TABLE[access_idx]["id"], ACCESS_TABLE[access_idx]["scope"].split())

utc_dt = datetime.utcnow()
response = await test_app_asyncio.post("/annotations/", data=json.dumps(payload), headers=auth)
response = await test_app_asyncio.post("/annotations/", json=payload, headers=auth)
assert response.status_code == status_code
if isinstance(status_details, str):
assert response.json()["detail"] == status_details
Expand All @@ -133,13 +141,18 @@ async def test_create_annotation(
@pytest.mark.parametrize(
"access_idx, payload, annotation_id, status_code, status_details",
[
[None, {"media_id": 1}, 1, 401, "Not authenticated"],
[0, {"media_id": 1}, 1, 403, "Your access scope is not compatible with this operation."],
[1, {"media_id": 1}, 1, 200, None],
[None, {"media_id": 1, "observations": []}, 1, 401, "Not authenticated"],
[0, {"media_id": 1, "observations": []}, 1, 403, "Your access scope is not compatible with this operation."],
[1, {"media_id": 1, "observations": []}, 1, 200, None],
[1, {"media_id": 1, "observations": [1337]}, 1, 422, None],
[1, {"media_id": 1, "observations": ["smoke"]}, 1, 200, None],
[1, {"media_id": 1, "observations": ["smoke", "fire", "puppy"]}, 1, 422, None],
[1, {}, 1, 422, None],
[1, {"media_id": "alpha"}, 1, 422, None],
[1, {"media_id": 1}, 999, 404, "Table annotations has no entry with id=999"],
[1, {"media_id": 1}, 0, 422, None],
[1, {"media_id": 1,}, 1, 422, None],
[1, {"observations": []}, 1, 422, None],
[1, {"media_id": "alpha", "observations": []}, 1, 422, None],
[1, {"media_id": 1, "observations": []}, 999, 404, "Table annotations has no entry with id=999"],
[1, {"media_id": 1, "observations": []}, 0, 422, None],
],
)
@pytest.mark.asyncio
Expand Down

0 comments on commit ceec8b5

Please sign in to comment.