diff --git a/examples/pagination_sqlalchemy.py b/examples/pagination_sqlalchemy.py index 53917383..0bef1427 100644 --- a/examples/pagination_sqlalchemy.py +++ b/examples/pagination_sqlalchemy.py @@ -5,7 +5,7 @@ from faker import Faker from fastapi import Depends, FastAPI from pydantic import BaseModel -from sqlalchemy import create_engine, select +from sqlalchemy import create_engine, select, text from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, Session, mapped_column from fastapi_pagination import LimitOffsetPage, Page, add_pagination @@ -74,8 +74,8 @@ def create_user(user_in: UserIn, db: Session = Depends(get_db)) -> UserOut: @app.get("/users/default", response_model=Page[UserOut]) @app.get("/users/limit-offset", response_model=LimitOffsetPage[UserOut]) def get_users(db: Session = Depends(get_db)) -> Any: - return paginate(db, select(User)) + return paginate(db, select(User).from_statement(text("""SELECT * FROM users"""))) if __name__ == "__main__": - uvicorn.run("pagination_sqlalchemy:app") + uvicorn.run(app) diff --git a/fastapi_pagination/ext/sqlalchemy.py b/fastapi_pagination/ext/sqlalchemy.py index 8df4b07f..18b9ea0f 100644 --- a/fastapi_pagination/ext/sqlalchemy.py +++ b/fastapi_pagination/ext/sqlalchemy.py @@ -16,7 +16,7 @@ from sqlalchemy import func, select, text from sqlalchemy.exc import InvalidRequestError -from sqlalchemy.orm import Query, Session, noload, scoped_session +from sqlalchemy.orm import FromStatement, Query, Session, noload, scoped_session from sqlalchemy.sql.elements import TextClause from typing_extensions import TypeAlias, deprecated @@ -61,7 +61,7 @@ def __init__(self, *_: Any, **__: Any) -> None: AsyncConn: TypeAlias = "Union[AsyncSession, AsyncConnection, async_scoped_session]" SyncConn: TypeAlias = "Union[Session, Connection, scoped_session]" -Selectable: TypeAlias = "Union[Select, TextClause]" +Selectable: TypeAlias = "Union[Select, TextClause, FromStatement]" def create_paginate_query_from_text(query: str, params: AbstractParams) -> str: @@ -89,9 +89,17 @@ def paginate_query(query: Select, params: AbstractParams) -> Select: return create_paginate_query(query, params) # type: ignore[return-value] +def _paginate_from_statement(query: FromStatement, params: AbstractParams) -> FromStatement: + query = query._generate() # type: ignore[attr-defined] + query.element = create_paginate_query(query.element, params) + return query + + def create_paginate_query(query: Selectable, params: AbstractParams) -> Selectable: if isinstance(query, TextClause): return text(create_paginate_query_from_text(query.text, params)) + if isinstance(query, FromStatement): + return _paginate_from_statement(query, params) return generic_query_apply_params(query, params.to_raw_params().as_limit_offset()) @@ -99,6 +107,8 @@ def create_paginate_query(query: Selectable, params: AbstractParams) -> Selectab def create_count_query(query: Selectable, *, use_subquery: bool = True) -> Selectable: if isinstance(query, TextClause): return text(create_count_query_from_text(query.text)) + if isinstance(query, FromStatement): + return create_count_query(query.element) query = query.order_by(None).options(noload("*")) diff --git a/tests/ext/test_sqlalchemy_from_statement.py b/tests/ext/test_sqlalchemy_from_statement.py new file mode 100644 index 00000000..751e059d --- /dev/null +++ b/tests/ext/test_sqlalchemy_from_statement.py @@ -0,0 +1,37 @@ +from typing import Iterator, Type + +from fastapi import Depends, FastAPI +from pytest import fixture +from sqlalchemy import select, text +from sqlalchemy.orm.session import Session + +from fastapi_pagination import LimitOffsetPage, Page, add_pagination +from fastapi_pagination.ext.sqlalchemy import paginate + +from ..base import BasePaginationTestCase +from ..utils import OptionalLimitOffsetPage, OptionalPage + + +@fixture(scope="session") +def app(sa_user, sa_session: Type[Session], model_cls: Type[object]): + app = FastAPI() + + def get_db() -> Iterator[Session]: + db = sa_session() + try: + yield db + finally: + db.close() + + @app.get("/default", response_model=Page[model_cls]) + @app.get("/limit-offset", response_model=LimitOffsetPage[model_cls]) + @app.get("/optional/default", response_model=OptionalPage[model_cls]) + @app.get("/optional/limit-offset", response_model=OptionalLimitOffsetPage[model_cls]) + def route(db: Session = Depends(get_db)): + return paginate(db, select(sa_user).from_statement(text("SELECT * FROM users"))) + + return add_pagination(app) + + +class TestSQLAlchemyFromStatement(BasePaginationTestCase): + pagination_types = ["default", "optional"]