-
-
Notifications
You must be signed in to change notification settings - Fork 390
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enhancement for FastAPI lifespan support (#1371)
- Loading branch information
Showing
9 changed files
with
335 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# mypy: no-disallow-untyped-decorators | ||
# pylint: disable=E0611,E0401 | ||
import os | ||
|
||
import pytest | ||
from asgi_lifespan import LifespanManager | ||
from httpx import AsyncClient | ||
from main import LOG_FILE, app | ||
from models import Users | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def anyio_backend(): | ||
return "asyncio" | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
async def client(): | ||
if LOG_FILE.exists(): | ||
LOG_FILE.unlink() | ||
async with LifespanManager(app): | ||
async with AsyncClient(app=app, base_url="http://test") as c: | ||
yield c | ||
assert not LOG_FILE.exists() | ||
|
||
|
||
@pytest.mark.anyio | ||
async def test_create_user(client: AsyncClient): # nosec | ||
response = await client.post("/users", json={"username": "admin"}) | ||
assert response.status_code == 200, response.text | ||
data = response.json() | ||
assert data["username"] == "admin" | ||
assert "id" in data | ||
user_id = data["id"] | ||
|
||
user_obj = await Users.get(id=user_id) | ||
assert user_obj.id == user_id | ||
|
||
|
||
@pytest.mark.anyio | ||
async def test_lifespan(client: AsyncClient): # nosec | ||
if os.getenv("USE_LIFESPAN"): | ||
assert LOG_FILE.exists() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# pylint: disable=E0611,E0401 | ||
import os | ||
from contextlib import asynccontextmanager | ||
from pathlib import Path | ||
from typing import List | ||
|
||
from fastapi import FastAPI | ||
from models import User_Pydantic, UserIn_Pydantic, Users | ||
from pydantic import BaseModel | ||
from starlette.exceptions import HTTPException | ||
|
||
from tortoise.contrib.fastapi import register_tortoise | ||
|
||
LOG_FILE = Path(__file__).parent / "foo.log" | ||
|
||
|
||
@asynccontextmanager | ||
async def lifespan(app: FastAPI): | ||
print("app startup") | ||
if not LOG_FILE.exists(): | ||
LOG_FILE.touch() | ||
yield | ||
print("app teardown") | ||
if LOG_FILE.exists(): | ||
LOG_FILE.unlink() | ||
|
||
|
||
if os.getenv("USE_LIFESPAN"): | ||
app = FastAPI(title="Tortoise ORM FastAPI test", lifespan=lifespan) | ||
else: | ||
app = FastAPI(title="Tortoise ORM FastAPI test") | ||
|
||
|
||
class Status(BaseModel): | ||
message: str | ||
|
||
|
||
@app.get("/users", response_model=List[User_Pydantic]) | ||
async def get_users(): | ||
return await User_Pydantic.from_queryset(Users.all()) | ||
|
||
|
||
@app.post("/users", response_model=User_Pydantic) | ||
async def create_user(user: UserIn_Pydantic): | ||
user_obj = await Users.create(**user.model_dump(exclude_unset=True)) | ||
return await User_Pydantic.from_tortoise_orm(user_obj) | ||
|
||
|
||
@app.get("/user/{user_id}", response_model=User_Pydantic) | ||
async def get_user(user_id: int): | ||
return await User_Pydantic.from_queryset_single(Users.get(id=user_id)) | ||
|
||
|
||
@app.put("/user/{user_id}", response_model=User_Pydantic) | ||
async def update_user(user_id: int, user: UserIn_Pydantic): | ||
await Users.filter(id=user_id).update(**user.model_dump(exclude_unset=True)) | ||
return await User_Pydantic.from_queryset_single(Users.get(id=user_id)) | ||
|
||
|
||
@app.delete("/user/{user_id}", response_model=Status) | ||
async def delete_user(user_id: int): | ||
deleted_count = await Users.filter(id=user_id).delete() | ||
if not deleted_count: | ||
raise HTTPException(status_code=404, detail=f"User {user_id} not found") | ||
return Status(message=f"Deleted user {user_id}") | ||
|
||
|
||
register_tortoise( | ||
app, | ||
db_url="sqlite://:memory:", | ||
modules={"models": ["models"]}, | ||
generate_schemas=True, | ||
add_exception_handlers=True, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from tortoise import fields, models | ||
from tortoise.contrib.pydantic import pydantic_model_creator | ||
|
||
|
||
class Users(models.Model): | ||
""" | ||
The User model | ||
""" | ||
|
||
id = fields.IntField(pk=True) | ||
#: This is a username | ||
username = fields.CharField(max_length=20, unique=True) | ||
name = fields.CharField(max_length=50, null=True) | ||
family_name = fields.CharField(max_length=50, null=True) | ||
category = fields.CharField(max_length=30, default="misc") | ||
password_hash = fields.CharField(max_length=128, null=True) | ||
created_at = fields.DatetimeField(auto_now_add=True) | ||
modified_at = fields.DatetimeField(auto_now=True) | ||
|
||
def full_name(self) -> str: | ||
""" | ||
Returns the best name | ||
""" | ||
if self.name or self.family_name: | ||
return f"{self.name or ''} {self.family_name or ''}".strip() | ||
return self.username | ||
|
||
class PydanticMeta: | ||
computed = ["full_name"] | ||
exclude = ["password_hash"] | ||
|
||
|
||
User_Pydantic = pydantic_model_creator(Users, name="User") | ||
UserIn_Pydantic = pydantic_model_creator(Users, name="UserIn", exclude_readonly=True) |
Oops, something went wrong.