Skip to content

Commit 94639f6

Browse files
committed
Update old schemas and routes to use new ORM layout
1 parent e67fc93 commit 94639f6

File tree

9 files changed

+92
-89
lines changed

9 files changed

+92
-89
lines changed

api/models/schemas/old/infraction.py

+14-14
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,27 @@
1-
from enum import Enum
1+
from pydantic import BaseModel, validator
22

3-
from pydantic import BaseModel
3+
from api.models.orm.infraction import InfractionType
4+
from api.models.schemas.utils import discord_ids_must_be_snowflake
45

56

6-
class InfractionType(str, Enum):
7-
"""An enumeration of codejam infraction types."""
8-
9-
note = "note"
10-
ban = "ban"
11-
warning = "warning"
12-
13-
14-
class Infraction(BaseModel):
15-
"""A model representing an infraction."""
7+
class InfractionBase(BaseModel):
8+
"""Base model for all infraction types."""
169

1710
user_id: int
1811
jam_id: int
1912
reason: str
2013
infraction_type: InfractionType
2114

15+
# validators
16+
_ensure_valid_discord_id = validator("user_id", allow_reuse=True)(discord_ids_must_be_snowflake)
2217

23-
class InfractionResponse(Infraction):
24-
"""Response model representing an infraction."""
18+
19+
class InfractionCreate(InfractionBase):
20+
"""The expected fields to create a new infraction."""
21+
22+
23+
class Infraction(InfractionBase):
24+
"""A model representing an infraction."""
2525

2626
id: int
2727

api/models/schemas/old/jam.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,23 @@
33
from api.models.schemas.old import infraction, team, winner
44

55

6-
class CodeJam(BaseModel):
7-
"""A model representing a codejam."""
6+
class CodeJamBase(BaseModel):
7+
"""A Base model representing a codejam."""
88

99
name: str
1010
teams: list[team.Team]
1111
ongoing: bool = False
1212

1313

14-
class CodeJamResponse(CodeJam):
14+
class CodeJamCreate(CodeJamBase):
15+
"""The expected fields to create a new Code Jam."""
16+
17+
18+
class CodeJam(CodeJamBase):
1519
"""Response model representing a code jam."""
1620

1721
id: int
18-
teams: list[team.TeamResponse]
19-
infractions: list[infraction.InfractionResponse]
22+
infractions: list[infraction.Infraction]
2023
winners: list[winner.Winner]
2124

2225
class Config:

api/models/schemas/old/team.py

+3-16
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,16 @@
55
from api.models.schemas.old import user
66

77

8-
class Team(BaseModel):
9-
"""A model representing a team for a codejam."""
8+
class TeamBase(BaseModel):
9+
"""A Base model representing a team for a codejam."""
1010

1111
name: str
1212
users: list[user.User]
1313
discord_role_id: Optional[int] = None
1414
discord_channel_id: Optional[int] = None
1515

1616

17-
class TeamResponse(Team):
17+
class Team(TeamBase):
1818
"""Response model representing a team."""
1919

2020
id: int
@@ -24,16 +24,3 @@ class Config:
2424
"""Sets ORM mode to true so that pydantic will validate the objects returned by SQLAlchemy."""
2525

2626
orm_mode = True
27-
28-
29-
class UserTeamResponse(BaseModel):
30-
"""Response model representing user and team relationship."""
31-
32-
user_id: int
33-
team: TeamResponse
34-
is_leader: bool
35-
36-
class Config:
37-
"""Sets ORM mode to true so that pydantic will validate the objects returned by SQLAlchemy."""
38-
39-
orm_mode = True

api/models/schemas/old/user.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class ParticipationHistory(BaseModel):
2323
first_place: bool
2424
team_id: int
2525
is_leader: bool
26-
infractions: list[infraction.InfractionResponse]
26+
infractions: list[infraction.Infraction]
2727

2828
class Config:
2929
"""Sets ORM mode to true so that pydantic will validate the objects returned by SQLAlchemy."""

api/models/schemas/utils.py

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
def discord_ids_must_be_snowflake(field_to_check: int) -> int:
2+
"""Ensure the ids are valid Discord snowflakes."""
3+
if field_to_check and field_to_check.bit_length() > 64:
4+
raise ValueError("Field must fit within a 64 bit int.")
5+
return field_to_check

api/routers/old/__init__.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from fastapi import APIRouter
22

3-
from api.routers.old import codejams, infractions, teams, users, winners
3+
from api.routers.old import codejams # , infractions, teams, users, winners
44

55
old_routes_router = APIRouter()
66
old_routes_router.include_router(codejams.router)
7-
old_routes_router.include_router(infractions.router)
8-
old_routes_router.include_router(teams.router)
9-
old_routes_router.include_router(users.router)
10-
old_routes_router.include_router(winners.router)
7+
# old_routes_router.include_router(infractions.router)
8+
# old_routes_router.include_router(teams.router)
9+
# old_routes_router.include_router(users.router)
10+
# old_routes_router.include_router(winners.router)

api/routers/old/codejams.py

+28-25
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,17 @@
44
from sqlalchemy import desc, update
55
from sqlalchemy.future import select
66

7-
from api.models import CodeJam, CodeJamResponse
87
from api.models.orm import Jam, Team, User
8+
from api.models.schemas.old import jam
99
from api.settings import DBSession
1010

1111
router = APIRouter(prefix="/codejams", tags=["codejams"])
1212

1313

1414
@router.get("/")
15-
async def get_codejams(session: DBSession) -> list[CodeJamResponse]:
15+
async def get_codejams(session: DBSession) -> list[jam.CodeJam]:
1616
"""Get all the codejams stored in the database."""
17-
codejams = await session.execute(select(Jam).order_by(desc(Jam.id)))
17+
codejams = await session.execute(select(Jam).order_by(desc(Jam.jam_id)))
1818
codejams.unique()
1919

2020
return codejams.scalars().all()
@@ -24,7 +24,7 @@ async def get_codejams(session: DBSession) -> list[CodeJamResponse]:
2424
"/{codejam_id}",
2525
responses={404: {"description": "CodeJam could not be found or there is no ongoing code jam."}},
2626
)
27-
async def get_codejam(codejam_id: int, session: DBSession) -> CodeJamResponse:
27+
async def get_codejam(codejam_id: int, session: DBSession) -> jam.CodeJam:
2828
"""
2929
Get a specific codejam stored in the database by ID.
3030
@@ -39,7 +39,7 @@ async def get_codejam(codejam_id: int, session: DBSession) -> CodeJamResponse:
3939
# With the current implementation, there should only be one ongoing codejam.
4040
return ongoing_jams[0]
4141

42-
jam_result = await session.execute(select(Jam).where(Jam.id == codejam_id))
42+
jam_result = await session.execute(select(Jam).where(Jam.jam_id == codejam_id))
4343
jam_result.unique()
4444

4545
if not (jam := jam_result.scalars().one_or_none()):
@@ -54,23 +54,23 @@ async def modify_codejam(
5454
session: DBSession,
5555
name: Optional[str] = None,
5656
ongoing: Optional[bool] = None,
57-
) -> CodeJamResponse:
57+
) -> jam.CodeJam:
5858
"""Modify the specified codejam to change its name and/or whether it's the ongoing code jam."""
59-
codejam = await session.execute(select(Jam).where(Jam.id == codejam_id))
59+
codejam = await session.execute(select(Jam).where(Jam.jam_id == codejam_id))
6060
codejam.unique()
6161

6262
if not codejam.scalars().one_or_none():
6363
raise HTTPException(status_code=404, detail="Code Jam with specified ID does not exist.")
6464

6565
if name is not None:
66-
await session.execute(update(Jam).where(Jam.id == codejam_id).values(name=name))
66+
await session.execute(update(Jam).where(Jam.jam_id == codejam_id).values(name=name))
6767

6868
if ongoing is not None:
6969
# Make sure no other Jams are ongoing, and set the specified codejam to ongoing.
7070
await session.execute(update(Jam).where(Jam.ongoing == True).values(ongoing=False))
71-
await session.execute(update(Jam).where(Jam.id == codejam_id).values(ongoing=True))
71+
await session.execute(update(Jam).where(Jam.jam_id == codejam_id).values(ongoing=True))
7272

73-
jam_result = await session.execute(select(Jam).where(Jam.id == codejam_id))
73+
jam_result = await session.execute(select(Jam).where(Jam.jam_id == codejam_id))
7474
jam_result.unique()
7575

7676
jam = jam_result.scalars().one()
@@ -79,7 +79,7 @@ async def modify_codejam(
7979

8080

8181
@router.post("/")
82-
async def create_codejam(codejam: CodeJam, session: DBSession) -> CodeJamResponse:
82+
async def create_codejam(codejam: jam.CodeJamCreate, session: DBSession) -> jam.CodeJam:
8383
"""
8484
Create a new codejam and get back the one just created.
8585
@@ -94,34 +94,37 @@ async def create_codejam(codejam: CodeJam, session: DBSession) -> CodeJamRespons
9494
await session.flush()
9595

9696
for raw_team in codejam.teams:
97-
team = Team(
98-
jam_id=jam.id,
99-
name=raw_team.name,
100-
discord_role_id=raw_team.discord_role_id,
101-
discord_channel_id=raw_team.discord_channel_id,
102-
)
103-
session.add(team)
104-
# Flush here to receive team ID
105-
await session.flush()
106-
97+
created_users = []
10798
for raw_user in raw_team.users:
99+
if raw_user.is_leader:
100+
team_leader_id = raw_user.user_id
108101
if (
109-
not (await session.execute(select(User).where(User.id == raw_user.user_id)))
102+
not (await session.execute(select(User).where(User.user_id == raw_user.user_id)))
110103
.unique()
111104
.scalars()
112105
.one_or_none()
113106
):
114107
user = User(id=raw_user.user_id)
108+
created_users.append(user)
115109
session.add(user)
116110

117-
team_user = TeamUser(team_id=team.id, user_id=raw_user.user_id, is_leader=raw_user.is_leader)
118-
session.add(team_user)
111+
team = Team(
112+
jam_id=jam.jam_id,
113+
name=raw_team.name,
114+
discord_role_id=raw_team.discord_role_id,
115+
discord_channel_id=raw_team.discord_channel_id,
116+
team_leader_id=team_leader_id,
117+
)
118+
team.users = created_users
119+
session.add(team)
120+
# Flush here to receive team ID
121+
await session.flush()
119122

120123
await session.flush()
121124

122125
# Pydantic, what is synchronous, may attempt to call async methods if current jam
123126
# object is returned. To avoid this, fetch all data here, in async context.
124-
jam_result = await session.execute(select(Jam).where(Jam.id == jam.id))
127+
jam_result = await session.execute(select(Jam).where(Jam.jam_id == jam.jam_id))
125128
jam_result.unique()
126129

127130
jam = jam_result.scalars().one()

api/routers/old/infractions.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
from fastapi import APIRouter, HTTPException
22
from sqlalchemy.future import select
33

4-
from api.models import Infraction, InfractionResponse
54
from api.models.orm import Infraction as DbInfraction
65
from api.models.orm import Jam, User
6+
from api.models.schemas.old.infraction import Infraction, InfractionCreate
77
from api.settings import DBSession
88

99
router = APIRouter(prefix="/infractions", tags=["infractions"])
1010

1111

1212
@router.get("/")
13-
async def get_infractions(session: DBSession) -> list[InfractionResponse]:
13+
async def get_infractions(session: DBSession) -> list[Infraction]:
1414
"""Get every infraction stored in the database."""
1515
infractions = await session.execute(select(DbInfraction))
1616
infractions.unique()
@@ -22,9 +22,9 @@ async def get_infractions(session: DBSession) -> list[InfractionResponse]:
2222
"/{infraction_id}",
2323
responses={404: {"description": "Infraction could not be found."}},
2424
)
25-
async def get_infraction(infraction_id: int, session: DBSession) -> InfractionResponse:
25+
async def get_infraction(infraction_id: int, session: DBSession) -> Infraction:
2626
"""Get a specific infraction stored in the database by ID."""
27-
infraction_result = await session.execute(select(DbInfraction).where(DbInfraction.id == infraction_id))
27+
infraction_result = await session.execute(select(DbInfraction).where(DbInfraction.infraction_id == infraction_id))
2828
infraction_result.unique()
2929

3030
if not (infraction := infraction_result.scalars().one_or_none()):
@@ -38,16 +38,18 @@ async def get_infraction(infraction_id: int, session: DBSession) -> InfractionRe
3838
responses={404: {"Description": "Jam ID or User ID could not be found."}},
3939
)
4040
async def create_infraction(
41-
infraction: Infraction,
41+
infraction: InfractionCreate,
4242
session: DBSession,
43-
) -> InfractionResponse:
43+
) -> Infraction:
4444
"""Add an infraction for a user to the database."""
45-
jam_id = (await session.execute(select(Jam.id).where(Jam.id == infraction.jam_id))).scalars().one_or_none()
45+
jam_id = (await session.execute(select(Jam.jam_id).where(Jam.jam_id == infraction.jam_id))).scalars().one_or_none()
4646

4747
if jam_id is None:
4848
raise HTTPException(404, "Jam with specified ID could not be found.")
4949

50-
user_id = (await session.execute(select(User.id).where(User.id == infraction.user_id))).scalars().one_or_none()
50+
user_id = (
51+
(await session.execute(select(User.user_id).where(User.user_id == infraction.user_id))).scalars().one_or_none()
52+
)
5153

5254
if user_id is None:
5355
raise HTTPException(404, "User with specified ID could not be found.")
@@ -58,7 +60,7 @@ async def create_infraction(
5860
session.add(infraction)
5961
await session.flush()
6062

61-
infraction_result = await session.execute(select(DbInfraction).where(DbInfraction.id == infraction.id))
63+
infraction_result = await session.execute(select(DbInfraction).where(DbInfraction.infraction_id == infraction.id))
6264
infraction_result.unique()
6365

6466
return infraction_result.scalars().one()

0 commit comments

Comments
 (0)