Skip to content
This repository has been archived by the owner on Sep 12, 2023. It is now read-only.

feat(repository)!: count() & list_and_count() methods #276

Open
wants to merge 11 commits into
base: 0.30
Choose a base branch
from
32 changes: 30 additions & 2 deletions src/starlite_saqlalchemy/repository/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@
from starlite_saqlalchemy.exceptions import NotFoundError

if TYPE_CHECKING:
from collections.abc import Sequence

from .types import FilterTypes

__all__ = ["AbstractRepository"]

T = TypeVar("T")
CollectionT = TypeVar("CollectionT")
RepoT = TypeVar("RepoT", bound="AbstractRepository")


Expand All @@ -38,6 +41,18 @@ async def add(self, data: T) -> T:
The added instance.
"""

@abstractmethod
async def count(self, *filters: FilterTypes, **kwargs: Any) -> int:
"""Get the count of records returned by a query.

Args:
*filters: Types for specific filtering operations.
**kwargs: Instance attribute value filters.

Returns:
The count of instances
"""

@abstractmethod
async def delete(self, id_: Any) -> T:
"""Delete instance identified by `id_`.
Expand Down Expand Up @@ -67,7 +82,7 @@ async def get(self, id_: Any) -> T:
"""

@abstractmethod
async def list(self, *filters: FilterTypes, **kwargs: Any) -> list[T]:
async def list(self, *filters: FilterTypes, **kwargs: Any) -> Sequence[T]:
"""Get a list of instances, optionally filtered.

Args:
Expand All @@ -78,6 +93,18 @@ async def list(self, *filters: FilterTypes, **kwargs: Any) -> list[T]:
The list of instances, after filtering applied.
"""

@abstractmethod
async def list_and_count(self, *filters: FilterTypes, **kwargs: Any) -> tuple[Sequence[T], int]:
"""

Args:
*filters: Types for specific filtering operations.
**kwargs: Instance attribute value filters.

Returns:
Count of records returned by query, ignoring pagination.
"""

@abstractmethod
async def update(self, data: T) -> T:
"""Update instance with the attribute values present on `data`.
Expand Down Expand Up @@ -113,12 +140,13 @@ async def upsert(self, data: T) -> T:
"""

@abstractmethod
def filter_collection_by_kwargs(self, **kwargs: Any) -> None:
def filter_collection_by_kwargs(self, collection: CollectionT, /, **kwargs: Any) -> CollectionT:
"""Filter the collection by kwargs.

Has `AND` semantics where multiple kwargs name/value pairs are provided.

Args:
collection: the collection to be filtered
**kwargs: key/value pairs such that objects remaining in the collection after filtering
have the property that their attribute named `key` has value equal to `value`.

Expand Down
163 changes: 126 additions & 37 deletions src/starlite_saqlalchemy/repository/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
from __future__ import annotations

from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast

from sqlalchemy import select, text
from sqlalchemy import over, select, text
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.sql import func as sql_func

from starlite_saqlalchemy.exceptions import ConflictError, StarliteSaqlalchemyError
from starlite_saqlalchemy.repository.abc import AbstractRepository
Expand Down Expand Up @@ -33,7 +34,9 @@

T = TypeVar("T")
ModelT = TypeVar("ModelT", bound="orm.Base | orm.AuditBase")
RowT = TypeVar("RowT", bound=tuple[Any, ...])
SQLARepoT = TypeVar("SQLARepoT", bound="SQLAlchemyRepository")
SelectT = TypeVar("SelectT", bound="Select[Any]")


@contextmanager
Expand All @@ -60,17 +63,13 @@ def wrap_sqlalchemy_exception() -> Any:
class SQLAlchemyRepository(AbstractRepository[ModelT], Generic[ModelT]):
"""SQLAlchemy based implementation of the repository interface."""

def __init__(
self, *, session: AsyncSession, select_: Select[tuple[ModelT]] | None = None, **kwargs: Any
) -> None:
def __init__(self, *, session: AsyncSession, **kwargs: Any) -> None:
"""
Args:
session: Session managing the unit-of-work for the operation.
select_: To facilitate customization of the underlying select query.
"""
super().__init__(**kwargs)
self.session = session
self._select = select(self.model_type) if select_ is None else select_

async def add(self, data: ModelT) -> ModelT:
"""Add `data` to the collection.
Expand All @@ -88,6 +87,33 @@ async def add(self, data: ModelT) -> ModelT:
self.session.expunge(instance)
return instance

async def count(self, *filters: FilterTypes, **kwargs: Any) -> int:
"""

Args:
*filters: Types for specific filtering operations.
**kwargs: Instance attribute value filters.

Returns:
Count of records returned by query, ignoring pagination.
"""
select_ = select(sql_func.count(self.model_type.id)) # type:ignore[attr-defined]
for filter_ in filters:
match filter_:
case LimitOffset(_, _):
pass
# we do not apply this filter to the count since we need the total rows
case BeforeAfter(field_name, before, after):
select_ = self._filter_on_datetime_field(
field_name, before, after, select_=select_
)
case CollectionFilter(field_name, values):
select_ = self._filter_in_collection(field_name, values, select_=select_)
case _:
raise StarliteSaqlalchemyError(f"Unexpected filter: {filter}")
results = await self._execute(select_)
return results.scalar_one() # type: ignore[no-any-return]

async def delete(self, id_: Any) -> ModelT:
"""Delete instance identified by `id_`.

Expand Down Expand Up @@ -119,14 +145,15 @@ async def get(self, id_: Any) -> ModelT:
Raises:
RepositoryNotFoundException: If no instance found identified by `id_`.
"""
select_ = self._create_select_for_model()
with wrap_sqlalchemy_exception():
self._filter_select_by_kwargs(**{self.id_attribute: id_})
instance = (await self._execute()).scalar_one_or_none()
select_ = self._filter_select_by_kwargs(select_, **{self.id_attribute: id_})
instance = (await self._execute(select_)).scalar_one_or_none()
instance = self.check_not_found(instance)
self.session.expunge(instance)
return instance

async def list(self, *filters: FilterTypes, **kwargs: Any) -> list[ModelT]:
async def list(self, *filters: FilterTypes, **kwargs: Any) -> abc.Sequence[ModelT]:
"""Get a list of instances, optionally filtered.

Args:
Expand All @@ -136,25 +163,46 @@ async def list(self, *filters: FilterTypes, **kwargs: Any) -> list[ModelT]:
Returns:
The list of instances, after filtering applied.
"""
for filter_ in filters:
match filter_:
case LimitOffset(limit, offset):
self._apply_limit_offset_pagination(limit, offset)
case BeforeAfter(field_name, before, after):
self._filter_on_datetime_field(field_name, before, after)
case CollectionFilter(field_name, values):
self._filter_in_collection(field_name, values)
case _:
raise StarliteSaqlalchemyError(f"Unexpected filter: {filter}")
self._filter_select_by_kwargs(**kwargs)
select_ = self._create_select_for_model()
select_ = self._filter_for_list(*filters, select_=select_)
select_ = self._filter_select_by_kwargs(select_, **kwargs)

with wrap_sqlalchemy_exception():
result = await self._execute()
result = await self._execute(select_)
instances = list(result.scalars())
for instance in instances:
self.session.expunge(instance)
return instances

async def list_and_count(
self, *filters: FilterTypes, **kwargs: Any
) -> tuple[abc.Sequence[ModelT], int]:
"""

Args:
*filters: Types for specific filtering operations.
**kwargs: Instance attribute value filters.

Returns:
Count of records returned by query, ignoring pagination.
"""
select_ = select(
self.model_type,
over(sql_func.count(self.model_type.id)), # type:ignore[attr-defined]
peterschutt marked this conversation as resolved.
Show resolved Hide resolved
)
select_ = self._filter_for_list(*filters, select_=select_)
select_ = self._filter_select_by_kwargs(select_, **kwargs)
with wrap_sqlalchemy_exception():
result = await self._execute(select_)
count: int = 0
instances: list[ModelT] = []
for i, (instance, count_value) in enumerate(result):
self.session.expunge(instance)
instances.append(instance)
if i == 0:
count = count_value
return instances, count

async def update(self, data: ModelT) -> ModelT:
"""Update instance with the attribute values present on `data`.

Expand Down Expand Up @@ -203,15 +251,21 @@ async def upsert(self, data: ModelT) -> ModelT:
self.session.expunge(instance)
return instance

def filter_collection_by_kwargs(self, **kwargs: Any) -> None:
def filter_collection_by_kwargs( # type:ignore[override]
self,
collection: SelectT,
/,
**kwargs: Any,
) -> SelectT:
"""Filter the collection by kwargs.

Args:
collection: select to filter
**kwargs: key/value pairs such that objects remaining in the collection after filtering
have the property that their attribute named `key` has value equal to `value`.
"""
with wrap_sqlalchemy_exception():
self._select.filter_by(**kwargs)
return collection.filter_by(**kwargs)

@classmethod
async def check_health(cls, session: AsyncSession) -> bool:
Expand All @@ -229,8 +283,10 @@ async def check_health(cls, session: AsyncSession) -> bool:

# the following is all sqlalchemy implementation detail, and shouldn't be directly accessed

def _apply_limit_offset_pagination(self, limit: int, offset: int) -> None:
self._select = self._select.limit(limit).offset(offset)
def _apply_limit_offset_pagination(
self, limit: int, offset: int, *, select_: SelectT
) -> SelectT:
return select_.limit(limit).offset(offset)

async def _attach_to_session(
self, model: ModelT, strategy: Literal["add", "merge"] = "add"
Expand All @@ -256,23 +312,56 @@ async def _attach_to_session(
case _:
raise ValueError("Unexpected value for `strategy`, must be `'add'` or `'merge'`")

async def _execute(self) -> Result[tuple[ModelT, ...]]:
return await self.session.execute(self._select)
def _create_select_for_model(self) -> Select[tuple[ModelT]]:
return select(self.model_type)

async def _execute(self, select_: Select[RowT]) -> Result[RowT]:
return cast("Result[RowT]", await self.session.execute(select_))

def _filter_for_list(self, *filters: FilterTypes, select_: SelectT) -> SelectT:
"""
Args:
*filters: filter types to apply to the query

Keyword Args:
select_: select to apply filters against

Returns:
The select with filters applied.
"""
for filter_ in filters:
match filter_:
case LimitOffset(limit, offset):
select_ = self._apply_limit_offset_pagination(limit, offset, select_=select_)
case BeforeAfter(field_name, before, after):
select_ = self._filter_on_datetime_field(
field_name, before, after, select_=select_
)
case CollectionFilter(field_name, values):
select_ = self._filter_in_collection(field_name, values, select_=select_)
case _:
raise StarliteSaqlalchemyError(f"Unexpected filter: {filter}")
return select_

def _filter_in_collection(self, field_name: str, values: abc.Collection[Any]) -> None:
def _filter_in_collection(
self, field_name: str, values: abc.Collection[Any], *, select_: SelectT
) -> SelectT:
if not values:
return
self._select = self._select.where(getattr(self.model_type, field_name).in_(values))
return select_

return select_.where(getattr(self.model_type, field_name).in_(values))

def _filter_on_datetime_field(
self, field_name: str, before: datetime | None, after: datetime | None
) -> None:
self, field_name: str, before: datetime | None, after: datetime | None, *, select_: SelectT
) -> SelectT:
field = getattr(self.model_type, field_name)
if before is not None:
self._select = self._select.where(field < before)
select_ = select_.where(field < before)
if after is not None:
self._select = self._select.where(field > before)
return select_.where(field > before)
return select_

def _filter_select_by_kwargs(self, **kwargs: Any) -> None:
def _filter_select_by_kwargs(self, select_: SelectT, **kwargs: Any) -> SelectT:
for key, val in kwargs.items():
self._select = self._select.where(getattr(self.model_type, key) == val)
select_ = select_.where(getattr(self.model_type, key) == val)
return select_
26 changes: 24 additions & 2 deletions src/starlite_saqlalchemy/service/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
"""
from __future__ import annotations

import asyncio
import contextlib
from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar

from starlite_saqlalchemy import constants
from starlite_saqlalchemy.exceptions import NotFoundError

if TYPE_CHECKING:
from collections.abc import AsyncIterator
from collections.abc import AsyncIterator, Sequence


T = TypeVar("T")
Expand Down Expand Up @@ -42,6 +43,16 @@ def __init_subclass__(cls, *_: Any, **__: Any) -> None:

# pylint:disable=unused-argument

async def count(self, **kwargs: Any) -> int:
"""
Args:
**kwargs: key value pairs of filter types.

Returns:
A count of the collection, filtered, but ignoring pagination.
"""
return 0

async def create(self, data: T) -> T:
"""Create an instance of `T`.

Expand All @@ -53,7 +64,7 @@ async def create(self, data: T) -> T:
"""
return data

async def list(self, **kwargs: Any) -> list[T]:
async def list(self, **kwargs: Any) -> Sequence[T]:
"""Return view of the collection of `T`.

Args:
Expand All @@ -64,6 +75,17 @@ async def list(self, **kwargs: Any) -> list[T]:
"""
return []

async def list_and_count(self, **kwargs: Any) -> tuple[Sequence[T], int]:
"""
Args:
**kwargs: Keyword arguments for filtering.

Returns:
List of instances and count of total collection, ignoring pagination.
"""
collection, count = await asyncio.gather(self.list(**kwargs), self.count(**kwargs))
peterschutt marked this conversation as resolved.
Show resolved Hide resolved
return collection, count

async def update(self, id_: Any, data: T) -> T:
"""Update existing instance of `T` with `data`.

Expand Down
Loading