Skip to content

Commit

Permalink
Add support for sqlalchemy from statement (#1277)
Browse files Browse the repository at this point in the history
  • Loading branch information
uriyyo authored Sep 8, 2024
1 parent c6d520b commit 49734f5
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 5 deletions.
6 changes: 3 additions & 3 deletions examples/pagination_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from faker import Faker
from fastapi import Depends, FastAPI
from pydantic import BaseModel
from sqlalchemy import create_engine, select
from sqlalchemy import create_engine, select, text
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, Session, mapped_column

from fastapi_pagination import LimitOffsetPage, Page, add_pagination
Expand Down Expand Up @@ -74,8 +74,8 @@ def create_user(user_in: UserIn, db: Session = Depends(get_db)) -> UserOut:
@app.get("/users/default", response_model=Page[UserOut])
@app.get("/users/limit-offset", response_model=LimitOffsetPage[UserOut])
def get_users(db: Session = Depends(get_db)) -> Any:
return paginate(db, select(User))
return paginate(db, select(User).from_statement(text("""SELECT * FROM users""")))


if __name__ == "__main__":
uvicorn.run("pagination_sqlalchemy:app")
uvicorn.run(app)
14 changes: 12 additions & 2 deletions fastapi_pagination/ext/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from sqlalchemy import func, select, text
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.orm import Query, Session, noload, scoped_session
from sqlalchemy.orm import FromStatement, Query, Session, noload, scoped_session
from sqlalchemy.sql.elements import TextClause
from typing_extensions import TypeAlias, deprecated

Expand Down Expand Up @@ -61,7 +61,7 @@ def __init__(self, *_: Any, **__: Any) -> None:
AsyncConn: TypeAlias = "Union[AsyncSession, AsyncConnection, async_scoped_session]"
SyncConn: TypeAlias = "Union[Session, Connection, scoped_session]"

Selectable: TypeAlias = "Union[Select, TextClause]"
Selectable: TypeAlias = "Union[Select, TextClause, FromStatement]"


def create_paginate_query_from_text(query: str, params: AbstractParams) -> str:
Expand Down Expand Up @@ -89,16 +89,26 @@ def paginate_query(query: Select, params: AbstractParams) -> Select:
return create_paginate_query(query, params) # type: ignore[return-value]


def _paginate_from_statement(query: FromStatement, params: AbstractParams) -> FromStatement:
query = query._generate() # type: ignore[attr-defined]
query.element = create_paginate_query(query.element, params)
return query


def create_paginate_query(query: Selectable, params: AbstractParams) -> Selectable:
if isinstance(query, TextClause):
return text(create_paginate_query_from_text(query.text, params))
if isinstance(query, FromStatement):
return _paginate_from_statement(query, params)

return generic_query_apply_params(query, params.to_raw_params().as_limit_offset())


def create_count_query(query: Selectable, *, use_subquery: bool = True) -> Selectable:
if isinstance(query, TextClause):
return text(create_count_query_from_text(query.text))
if isinstance(query, FromStatement):
return create_count_query(query.element)

query = query.order_by(None).options(noload("*"))

Expand Down
37 changes: 37 additions & 0 deletions tests/ext/test_sqlalchemy_from_statement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import Iterator, Type

from fastapi import Depends, FastAPI
from pytest import fixture
from sqlalchemy import select, text
from sqlalchemy.orm.session import Session

from fastapi_pagination import LimitOffsetPage, Page, add_pagination
from fastapi_pagination.ext.sqlalchemy import paginate

from ..base import BasePaginationTestCase
from ..utils import OptionalLimitOffsetPage, OptionalPage


@fixture(scope="session")
def app(sa_user, sa_session: Type[Session], model_cls: Type[object]):
app = FastAPI()

def get_db() -> Iterator[Session]:
db = sa_session()
try:
yield db
finally:
db.close()

@app.get("/default", response_model=Page[model_cls])
@app.get("/limit-offset", response_model=LimitOffsetPage[model_cls])
@app.get("/optional/default", response_model=OptionalPage[model_cls])
@app.get("/optional/limit-offset", response_model=OptionalLimitOffsetPage[model_cls])
def route(db: Session = Depends(get_db)):
return paginate(db, select(sa_user).from_statement(text("SELECT * FROM users")))

return add_pagination(app)


class TestSQLAlchemyFromStatement(BasePaginationTestCase):
pagination_types = ["default", "optional"]

0 comments on commit 49734f5

Please sign in to comment.