Skip to content

Commit

Permalink
Add better logic for scalars unwrapping (#1281)
Browse files Browse the repository at this point in the history
  • Loading branch information
uriyyo committed Sep 13, 2024
1 parent d7be9b4 commit 384d761
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
4 changes: 2 additions & 2 deletions examples/pagination_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from faker import Faker
from fastapi import Depends, FastAPI
from pydantic import BaseModel
from sqlalchemy import create_engine, select, text
from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, Session, mapped_column

from fastapi_pagination import LimitOffsetPage, Page, add_pagination
Expand Down Expand Up @@ -74,7 +74,7 @@ def create_user(user_in: UserIn, db: Session = Depends(get_db)) -> UserOut:
@app.get("/users/default", response_model=Page[UserOut])
@app.get("/users/limit-offset", response_model=LimitOffsetPage[UserOut])
def get_users(db: Session = Depends(get_db)) -> Any:
return paginate(db, select(User).from_statement(text("""SELECT * FROM users""")))
return paginate(db, select(User))


if __name__ == "__main__":
Expand Down
17 changes: 14 additions & 3 deletions fastapi_pagination/ext/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.orm import FromStatement, Query, Session, noload, scoped_session
from sqlalchemy.sql.elements import TextClause
from typing_extensions import TypeAlias, deprecated
from typing_extensions import TypeAlias, deprecated, no_type_check

from ..api import apply_items_transformer, create_page
from ..bases import AbstractPage, AbstractParams, is_cursor
Expand Down Expand Up @@ -64,6 +64,14 @@ def __init__(self, *_: Any, **__: Any) -> None:
Selectable: TypeAlias = "Union[Select, TextClause, FromStatement]"


@no_type_check
def _should_unwrap_scalars(query: Selectable) -> bool:
try:
return len(query.column_descriptions) == 1 and len(query.columns) > 1
except AttributeError:
return True


def create_paginate_query_from_text(query: str, params: AbstractParams) -> str:
raw_params = params.to_raw_params().as_limit_offset()

Expand Down Expand Up @@ -174,7 +182,9 @@ def _apply_items_transformer(*args: Any, **kwargs: Any) -> Any:
per_page=raw_params.size,
page=raw_params.cursor, # type: ignore[arg-type]
)
items = unwrap_scalars([*page])
items = [*page]
if _should_unwrap_scalars(query):
items = unwrap_scalars(items)
items = _apply_items_transformer(items, transformer)

return create_page(
Expand All @@ -190,7 +200,8 @@ def _apply_items_transformer(*args: Any, **kwargs: Any) -> Any:

query = create_paginate_query(query, params)
items = _maybe_unique(conn.execute(query), unique)
items = unwrap_scalars(items)
if _should_unwrap_scalars(query):
items = unwrap_scalars(items)
items = _apply_items_transformer(items, transformer)

return create_page(
Expand Down

0 comments on commit 384d761

Please sign in to comment.