From 65e6bfecc654718e42b627d98f8657fc1cae9810 Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Fri, 27 Jan 2023 08:03:33 +1000 Subject: [PATCH 01/11] build(deps): pin to starlite <= 1.50.0 --- poetry.lock | 10 +++++----- pyproject.toml | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/poetry.lock b/poetry.lock index 307736b..ed12a85 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1042,14 +1042,14 @@ sqlcipher = ["sqlcipher3-binary"] [[package]] name = "starlite" -version = "1.50.1" +version = "1.50.0" description = "Performant, light and flexible ASGI API Framework" category = "main" optional = false python-versions = ">=3.8,<4.0" files = [ - {file = "starlite-1.50.1-py3-none-any.whl", hash = "sha256:31d45efc805b895b9733b58825942129e55ebbbf8ddaeaf931b1237af5daae37"}, - {file = "starlite-1.50.1.tar.gz", hash = "sha256:7cc30ace31c47c406666baff839e4dfa24e205855ef0fada011b9e1bbadda852"}, + {file = "starlite-1.50.0-py3-none-any.whl", hash = "sha256:9d3379dec91fdfa0fb4a2cb77619e41455b89f280258f253de361f2f94464313"}, + {file = "starlite-1.50.0.tar.gz", hash = "sha256:6ed8572252cbb60cf70909cfb16721c0cc59d95c9fb5eaf8a4b0c7e8cfd68313"}, ] [package.dependencies] @@ -1070,7 +1070,7 @@ brotli = ["brotli"] cli = ["click", "jsbeautifier", "rich (>=13.0.0)"] cryptography = ["cryptography"] full = ["aiomcache", "brotli", "click", "cryptography", "jinja2 (>=3.1.2)", "opentelemetry-instrumentation-asgi", "picologging", "python-jose", "redis[hiredis]", "rich (>=13.0.0)", "structlog"] -jinja = ["jinja2 (>=3.1.2)"] +jina = ["jinja2 (>=3.1.2)"] jwt = ["cryptography", "python-jose"] memcached = ["aiomcache"] opentelemetry = ["opentelemetry-instrumentation-asgi"] @@ -1215,4 +1215,4 @@ worker = ["saq", "hiredis"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "7a5ed929059b3c670d3a4d40e6ac7cf871900e33ff937dd1ff34e615b36f56c1" +content-hash = "f614b5407e536db1af6f68bee978b170304c642eee949564a23604729b90c167" diff --git a/pyproject.toml b/pyproject.toml index 7a22583..42566e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,7 @@ httpx = "*" msgspec = "*" pydantic = "*" python-dotenv = "*" -starlite = ">=1.40.1,<1.50.2" +starlite = ">=1.40.1,<=1.50.0" tenacity = "*" uvicorn = "*" uvloop = "*" From 14037c8085d93d871110b96bd85d472a7a8cc8cb Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Mon, 23 Jan 2023 21:50:14 -0600 Subject: [PATCH 02/11] feat: implements abstract count method and a paginated list that includes total rows --- src/starlite_saqlalchemy/repository/abc.py | 13 ++++++++++++ .../repository/sqlalchemy.py | 19 +++++++++++++++++ .../service/sqlalchemy.py | 21 +++++++++++++++++++ .../repository/test_sqlalchemy.py | 13 ++++++++++++ 4 files changed, 66 insertions(+) diff --git a/src/starlite_saqlalchemy/repository/abc.py b/src/starlite_saqlalchemy/repository/abc.py index 3f76e68..6411156 100644 --- a/src/starlite_saqlalchemy/repository/abc.py +++ b/src/starlite_saqlalchemy/repository/abc.py @@ -52,6 +52,19 @@ async def delete(self, id_: Any) -> T: RepositoryNotFoundException: If no instance found identified by `id_`. """ + @abstractmethod + async def count(self, *filters: FilterTypes, **kwargs: Any) -> int: + """Get the count of records returned by a query. Optionally filtered. + + + Args: + *filters: Types for specific filtering operations. + **kwargs: Instance attribute value filters. + + Returns: + The count of instances + """ + @abstractmethod async def get(self, id_: Any) -> T: """Get instance identified by `id_`. diff --git a/src/starlite_saqlalchemy/repository/sqlalchemy.py b/src/starlite_saqlalchemy/repository/sqlalchemy.py index 204743f..b3e545a 100644 --- a/src/starlite_saqlalchemy/repository/sqlalchemy.py +++ b/src/starlite_saqlalchemy/repository/sqlalchemy.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar from sqlalchemy import select, text +from sqlalchemy.sql import func as sql_func from sqlalchemy.exc import IntegrityError, SQLAlchemyError from starlite_saqlalchemy.exceptions import ConflictError, StarliteSaqlalchemyError @@ -203,6 +204,24 @@ async def upsert(self, data: ModelT) -> ModelT: self.session.expunge(instance) return instance + async def count(self, select_: Select[tuple[ModelT]] | None = None) -> int: + """Count records returned by query. + + Args: + select_ (Select | None): Optional SQL statement to generate a count statement. Defaults to [self._select] + + Returns: + int: _description_ + """ + if select_ is None: + select_ = self._select + count_statement = select_.with_only_columns( + sql_func.count(), + maintain_column_froms=True, + ).order_by(None) + results = await self.session.execute(count_statement) + return results.scalar_one() # type: ignore + def filter_collection_by_kwargs(self, **kwargs: Any) -> None: """Filter the collection by kwargs. diff --git a/src/starlite_saqlalchemy/service/sqlalchemy.py b/src/starlite_saqlalchemy/service/sqlalchemy.py index 5ede194..43259c1 100644 --- a/src/starlite_saqlalchemy/service/sqlalchemy.py +++ b/src/starlite_saqlalchemy/service/sqlalchemy.py @@ -122,3 +122,24 @@ async def new(cls: type[RepoServiceT]) -> AsyncIterator[RepoServiceT]: """ async with async_session_factory() as session: yield cls(session=session) + + +class PaginatedRepositoryService(RepositoryService): + """Paginated Service object that operates on a repository object.""" + + __id__ = "starlite_saqlalchemy.service.sqlalchemy.PaginatedRepositoryService" + + async def list(self, *filters: "FilterTypes", **kwargs: Any) -> tuple[list[ModelT], int]: + """Wrap repository scalars operation. + + Args: + *filters: Collection route filters. + **kwargs: Keyword arguments for attribute based filtering. + + Returns: + The list of instances retrieved from the repository. + """ + return ( + await self.repository.list(*filters, **kwargs), + await self.repository.count(*filters, **kwargs), + ) diff --git a/tests/unit/require_sqlalchemy/repository/test_sqlalchemy.py b/tests/unit/require_sqlalchemy/repository/test_sqlalchemy.py index d93d18c..83c72a7 100644 --- a/tests/unit/require_sqlalchemy/repository/test_sqlalchemy.py +++ b/tests/unit/require_sqlalchemy/repository/test_sqlalchemy.py @@ -119,6 +119,19 @@ async def test_sqlalchemy_repo_list_with_pagination( mock_repo._select.limit().offset.assert_called_once_with(3) # type:ignore[call-arg] +async def test_sqlalchemy_repo_count( + mock_repo: SQLAlchemyRepository, monkeypatch: MonkeyPatch +) -> None: + """Test count operation with pagination.""" + mock_instances = [MagicMock(), MagicMock()] + result_mock = MagicMock() + result_mock.scalars = MagicMock(len(mock_instances)) + execute_mock = AsyncMock(return_value=result_mock) + monkeypatch.setattr(mock_repo, "_execute", execute_mock) + instance_count = await mock_repo.count() + assert instance_count == len(mock_instances) + + async def test_sqlalchemy_repo_list_with_before_after_filter( mock_repo: SQLAlchemyRepository, monkeypatch: MonkeyPatch ) -> None: From 00dd3f748dc0e3ac61274446b92b3d68c9937a76 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Tue, 24 Jan 2023 13:53:57 -0600 Subject: [PATCH 03/11] feat: sqlalchemy pagination with total row count --- src/starlite_saqlalchemy/repository/abc.py | 11 +++------ .../repository/sqlalchemy.py | 12 +++++++--- .../service/sqlalchemy.py | 23 +------------------ .../repository/test_sqlalchemy.py | 18 +++++++++++---- 4 files changed, 26 insertions(+), 38 deletions(-) diff --git a/src/starlite_saqlalchemy/repository/abc.py b/src/starlite_saqlalchemy/repository/abc.py index 6411156..cba75d8 100644 --- a/src/starlite_saqlalchemy/repository/abc.py +++ b/src/starlite_saqlalchemy/repository/abc.py @@ -53,13 +53,8 @@ async def delete(self, id_: Any) -> T: """ @abstractmethod - async def count(self, *filters: FilterTypes, **kwargs: Any) -> int: - """Get the count of records returned by a query. Optionally filtered. - - - Args: - *filters: Types for specific filtering operations. - **kwargs: Instance attribute value filters. + async def count(self, select_: Any | None = None) -> int: + """Get the count of records returned by a query. Returns: The count of instances @@ -80,7 +75,7 @@ async def get(self, id_: Any) -> T: """ @abstractmethod - async def list(self, *filters: FilterTypes, **kwargs: Any) -> list[T]: + async def list(self, *filters: FilterTypes, **kwargs: Any) -> tuple[list[T], int]: """Get a list of instances, optionally filtered. Args: diff --git a/src/starlite_saqlalchemy/repository/sqlalchemy.py b/src/starlite_saqlalchemy/repository/sqlalchemy.py index b3e545a..ab7c55f 100644 --- a/src/starlite_saqlalchemy/repository/sqlalchemy.py +++ b/src/starlite_saqlalchemy/repository/sqlalchemy.py @@ -127,7 +127,7 @@ async def get(self, id_: Any) -> ModelT: self.session.expunge(instance) return instance - async def list(self, *filters: FilterTypes, **kwargs: Any) -> list[ModelT]: + async def list(self, *filters: FilterTypes, **kwargs: Any) -> tuple[list[ModelT], int]: """Get a list of instances, optionally filtered. Args: @@ -137,24 +137,30 @@ async def list(self, *filters: FilterTypes, **kwargs: Any) -> list[ModelT]: Returns: The list of instances, after filtering applied. """ + count_query = self._select for filter_ in filters: - match filter_: + match filter_: # noqa: E999 case LimitOffset(limit, offset): self._apply_limit_offset_pagination(limit, offset) + # we do not apply this filter to the count since we need the total rows case BeforeAfter(field_name, before, after): self._filter_on_datetime_field(field_name, before, after) + count_query = self._select case CollectionFilter(field_name, values): self._filter_in_collection(field_name, values) + count_query = self._select case _: raise StarliteSaqlalchemyError(f"Unexpected filter: {filter}") self._filter_select_by_kwargs(**kwargs) + count_query = self._select with wrap_sqlalchemy_exception(): result = await self._execute() + count = await self.count(count_query) instances = list(result.scalars()) for instance in instances: self.session.expunge(instance) - return instances + return instances, count async def update(self, data: ModelT) -> ModelT: """Update instance with the attribute values present on `data`. diff --git a/src/starlite_saqlalchemy/service/sqlalchemy.py b/src/starlite_saqlalchemy/service/sqlalchemy.py index 43259c1..ae18fef 100644 --- a/src/starlite_saqlalchemy/service/sqlalchemy.py +++ b/src/starlite_saqlalchemy/service/sqlalchemy.py @@ -50,7 +50,7 @@ async def create(self, data: ModelT) -> ModelT: """ return await self.repository.add(data) - async def list(self, *filters: "FilterTypes", **kwargs: Any) -> list[ModelT]: + async def list(self, *filters: "FilterTypes", **kwargs: Any) -> tuple[list[ModelT], int]: """Wrap repository scalars operation. Args: @@ -122,24 +122,3 @@ async def new(cls: type[RepoServiceT]) -> AsyncIterator[RepoServiceT]: """ async with async_session_factory() as session: yield cls(session=session) - - -class PaginatedRepositoryService(RepositoryService): - """Paginated Service object that operates on a repository object.""" - - __id__ = "starlite_saqlalchemy.service.sqlalchemy.PaginatedRepositoryService" - - async def list(self, *filters: "FilterTypes", **kwargs: Any) -> tuple[list[ModelT], int]: - """Wrap repository scalars operation. - - Args: - *filters: Collection route filters. - **kwargs: Keyword arguments for attribute based filtering. - - Returns: - The list of instances retrieved from the repository. - """ - return ( - await self.repository.list(*filters, **kwargs), - await self.repository.count(*filters, **kwargs), - ) diff --git a/tests/unit/require_sqlalchemy/repository/test_sqlalchemy.py b/tests/unit/require_sqlalchemy/repository/test_sqlalchemy.py index 83c72a7..f30cb61 100644 --- a/tests/unit/require_sqlalchemy/repository/test_sqlalchemy.py +++ b/tests/unit/require_sqlalchemy/repository/test_sqlalchemy.py @@ -4,6 +4,7 @@ from datetime import datetime from typing import TYPE_CHECKING +from unittest import mock from unittest.mock import AsyncMock, MagicMock, call import pytest @@ -97,10 +98,15 @@ async def test_sqlalchemy_repo_list( mock_instances = [MagicMock(), MagicMock()] result_mock = MagicMock() result_mock.scalars = MagicMock(return_value=mock_instances) + count_mock = MagicMock() + count_mock.return_value = 2 execute_mock = AsyncMock(return_value=result_mock) + execute_count_mock = AsyncMock(return_value=count_mock) + monkeypatch.setattr(mock_repo, "count", execute_count_mock) monkeypatch.setattr(mock_repo, "_execute", execute_mock) - instances = await mock_repo.list() + instances, count = await mock_repo.list() assert instances == mock_instances + assert count == count_mock mock_repo.session.expunge.assert_has_calls(*mock_instances) mock_repo.session.commit.assert_not_called() @@ -123,13 +129,15 @@ async def test_sqlalchemy_repo_count( mock_repo: SQLAlchemyRepository, monkeypatch: MonkeyPatch ) -> None: """Test count operation with pagination.""" - mock_instances = [MagicMock(), MagicMock()] result_mock = MagicMock() - result_mock.scalars = MagicMock(len(mock_instances)) + count_mock = MagicMock() execute_mock = AsyncMock(return_value=result_mock) + execute_count_mock = AsyncMock(return_value=count_mock) + monkeypatch.setattr(mock_repo, "count", execute_count_mock) monkeypatch.setattr(mock_repo, "_execute", execute_mock) - instance_count = await mock_repo.count() - assert instance_count == len(mock_instances) + mock_repo.count.return_value = 1 + count = await mock_repo.count() + assert count == 1 async def test_sqlalchemy_repo_list_with_before_after_filter( From 23d8b362ab051cbbb8d1cf7aa4d9fa259de29eee Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Jan 2023 03:56:33 +0000 Subject: [PATCH 04/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/starlite_saqlalchemy/repository/abc.py | 10 +++++++--- src/starlite_saqlalchemy/repository/sqlalchemy.py | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/starlite_saqlalchemy/repository/abc.py b/src/starlite_saqlalchemy/repository/abc.py index cba75d8..1ffde39 100644 --- a/src/starlite_saqlalchemy/repository/abc.py +++ b/src/starlite_saqlalchemy/repository/abc.py @@ -53,8 +53,12 @@ async def delete(self, id_: Any) -> T: """ @abstractmethod - async def count(self, select_: Any | None = None) -> int: - """Get the count of records returned by a query. + async def count(self, *filters: FilterTypes, **kwargs: Any) -> int: + """Get the count of records returned by a query. Optionally filtered. + + Args: + *filters: Types for specific filtering operations. + **kwargs: Instance attribute value filters. Returns: The count of instances @@ -75,7 +79,7 @@ async def get(self, id_: Any) -> T: """ @abstractmethod - async def list(self, *filters: FilterTypes, **kwargs: Any) -> tuple[list[T], int]: + async def list(self, *filters: FilterTypes, **kwargs: Any) -> list[T]: """Get a list of instances, optionally filtered. Args: diff --git a/src/starlite_saqlalchemy/repository/sqlalchemy.py b/src/starlite_saqlalchemy/repository/sqlalchemy.py index ab7c55f..b3d0d16 100644 --- a/src/starlite_saqlalchemy/repository/sqlalchemy.py +++ b/src/starlite_saqlalchemy/repository/sqlalchemy.py @@ -5,8 +5,8 @@ from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar from sqlalchemy import select, text -from sqlalchemy.sql import func as sql_func from sqlalchemy.exc import IntegrityError, SQLAlchemyError +from sqlalchemy.sql import func as sql_func from starlite_saqlalchemy.exceptions import ConflictError, StarliteSaqlalchemyError from starlite_saqlalchemy.repository.abc import AbstractRepository From f8b3798dbe6da9f59bae68da2aa6cb3aea2911c2 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Tue, 24 Jan 2023 13:56:13 -0600 Subject: [PATCH 05/11] fix: linting updates --- src/starlite_saqlalchemy/repository/sqlalchemy.py | 2 +- tests/unit/require_sqlalchemy/repository/test_sqlalchemy.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/starlite_saqlalchemy/repository/sqlalchemy.py b/src/starlite_saqlalchemy/repository/sqlalchemy.py index b3d0d16..a11c556 100644 --- a/src/starlite_saqlalchemy/repository/sqlalchemy.py +++ b/src/starlite_saqlalchemy/repository/sqlalchemy.py @@ -139,7 +139,7 @@ async def list(self, *filters: FilterTypes, **kwargs: Any) -> tuple[list[ModelT] """ count_query = self._select for filter_ in filters: - match filter_: # noqa: E999 + match filter_: case LimitOffset(limit, offset): self._apply_limit_offset_pagination(limit, offset) # we do not apply this filter to the count since we need the total rows diff --git a/tests/unit/require_sqlalchemy/repository/test_sqlalchemy.py b/tests/unit/require_sqlalchemy/repository/test_sqlalchemy.py index f30cb61..1d34de8 100644 --- a/tests/unit/require_sqlalchemy/repository/test_sqlalchemy.py +++ b/tests/unit/require_sqlalchemy/repository/test_sqlalchemy.py @@ -4,7 +4,6 @@ from datetime import datetime from typing import TYPE_CHECKING -from unittest import mock from unittest.mock import AsyncMock, MagicMock, call import pytest From 40de281a17969228bcf47b9ad3178247b26f2200 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Tue, 24 Jan 2023 14:12:29 -0600 Subject: [PATCH 06/11] fix: implement generic pattern for services & repo --- src/starlite_saqlalchemy/repository/abc.py | 3 +++ src/starlite_saqlalchemy/service/generic.py | 4 ++-- .../testing/generic_mock_repository.py | 16 ++++++++++++++-- 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/starlite_saqlalchemy/repository/abc.py b/src/starlite_saqlalchemy/repository/abc.py index 1ffde39..20b96cc 100644 --- a/src/starlite_saqlalchemy/repository/abc.py +++ b/src/starlite_saqlalchemy/repository/abc.py @@ -60,6 +60,9 @@ async def count(self, *filters: FilterTypes, **kwargs: Any) -> int: *filters: Types for specific filtering operations. **kwargs: Instance attribute value filters. + Args: + select_: Optional statement to use for counting. + Returns: The count of instances """ diff --git a/src/starlite_saqlalchemy/service/generic.py b/src/starlite_saqlalchemy/service/generic.py index 3390c2f..d733532 100644 --- a/src/starlite_saqlalchemy/service/generic.py +++ b/src/starlite_saqlalchemy/service/generic.py @@ -53,7 +53,7 @@ async def create(self, data: T) -> T: """ return data - async def list(self, **kwargs: Any) -> list[T]: + async def list(self, **kwargs: Any) -> tuple[list[T], int]: """Return view of the collection of `T`. Args: @@ -62,7 +62,7 @@ async def list(self, **kwargs: Any) -> list[T]: Returns: The list of instances retrieved from the repository. """ - return [] + return [], 0 async def update(self, id_: Any, data: T) -> T: """Update existing instance of `T` with `data`. diff --git a/src/starlite_saqlalchemy/testing/generic_mock_repository.py b/src/starlite_saqlalchemy/testing/generic_mock_repository.py index b598e76..18bf09e 100644 --- a/src/starlite_saqlalchemy/testing/generic_mock_repository.py +++ b/src/starlite_saqlalchemy/testing/generic_mock_repository.py @@ -102,7 +102,7 @@ async def get(self, id_: Any) -> ModelT: """ return self._find_or_raise_not_found(id_) - async def list(self, *filters: "FilterTypes", **kwargs: Any) -> list[ModelT]: + async def list(self, *filters: "FilterTypes", **kwargs: Any) -> tuple[list[ModelT], int]: """Get a list of instances, optionally filtered. Args: @@ -112,7 +112,19 @@ async def list(self, *filters: "FilterTypes", **kwargs: Any) -> list[ModelT]: Returns: The list of instances, after filtering applied. """ - return list(self.collection.values()) + return list(self.collection.values()), len(list(self.collection.values())) + + async def count(self, select_: Any | None = None) -> int: + """Get a list of instances, optionally filtered. + + Args: + *filters: Types for specific filtering operations. + **kwargs: Instance attribute value filters. + + Returns: + The list of instances, after filtering applied. + """ + return len(list(self.collection.values())) async def update(self, data: ModelT) -> ModelT: """Update instance with the attribute values present on `data`. From b0e10f4e2dbf3518d1f55efa3c658679895a2350 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Tue, 24 Jan 2023 14:23:06 -0600 Subject: [PATCH 07/11] fix: test case correction --- tests/unit/require_sqlalchemy/test_service.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/tests/unit/require_sqlalchemy/test_service.py b/tests/unit/require_sqlalchemy/test_service.py index 9aa77b2..14e33d6 100644 --- a/tests/unit/require_sqlalchemy/test_service.py +++ b/tests/unit/require_sqlalchemy/test_service.py @@ -36,14 +36,16 @@ async def test_service_create() -> None: async def test_service_list() -> None: """Test repository list action.""" - resp = await domain.authors.Service().list() - assert len(resp) == 2 + items, count = await domain.authors.Service().list() + assert isinstance(items, list) + assert count == 2 async def test_service_update() -> None: """Test repository update action.""" service_obj = domain.authors.Service() - author, _ = await service_obj.list() + authors, _ = await service_obj.list() + author = authors[0] assert author.name == "Agatha Christie" author.name = "different" resp = await service_obj.update(author.id, author) @@ -53,7 +55,8 @@ async def test_service_update() -> None: async def test_service_upsert_update() -> None: """Test repository upsert action for update.""" service_obj = domain.authors.Service() - author, _ = await service_obj.list() + authors, _ = await service_obj.list() + author = authors[0] assert author.name == "Agatha Christie" author.name = "different" resp = await service_obj.upsert(author.id, author) @@ -72,7 +75,8 @@ async def test_service_upsert_create() -> None: async def test_service_get() -> None: """Test repository get action.""" service_obj = domain.authors.Service() - author, _ = await service_obj.list() + authors, _ = await service_obj.list() + author = authors[0] retrieved = await service_obj.get(author.id) assert author is retrieved @@ -80,7 +84,8 @@ async def test_service_get() -> None: async def test_service_delete() -> None: """Test repository delete action.""" service_obj = domain.authors.Service() - author, _ = await service_obj.list() + authors, _ = await service_obj.list() + author = authors[0] deleted = await service_obj.delete(author.id) assert author is deleted @@ -96,7 +101,7 @@ async def test_service_method_default_behavior() -> None: service_obj = service.Service[object]() data = object() assert await service_obj.create(data) is data - assert await service_obj.list() == [] + assert await service_obj.list() == ([], 0) assert await service_obj.update("abc", data) is data assert await service_obj.upsert("abc", data) is data with pytest.raises(NotFoundError): From eeb13a34b968043136a9ba8f37ed58feb77f5bcb Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Wed, 25 Jan 2023 13:14:39 +1000 Subject: [PATCH 08/11] refactor(repository)!: looking for a better abstraction - decouple the collection (select statement) from the repo instance E.g., the sqlalchemy select that represents the collection shouldn't need to be bound to `self`. This means that any operation can use the utility methods,like `filter_collection_by_kwargs()` without having to worry about modifying a pre-constructed select to do what it wants. - make `list()` behave the same as it always did - add `count()` methods as standalone service and repo ops. - add `list_and_count()` methods as standalone service and repo ops. --- src/starlite_saqlalchemy/repository/abc.py | 47 +++-- .../repository/sqlalchemy.py | 185 ++++++++++++------ src/starlite_saqlalchemy/service/generic.py | 28 ++- .../service/sqlalchemy.py | 26 ++- .../testing/generic_mock_repository.py | 35 +++- tests/integration/conftest.py | 6 +- .../repository/test_sqlalchemy_repository.py | 21 +- .../test_generic_mock_repository.py | 8 +- .../repository/test_sqlalchemy.py | 60 +++--- tests/unit/require_sqlalchemy/test_service.py | 18 +- 10 files changed, 303 insertions(+), 131 deletions(-) diff --git a/src/starlite_saqlalchemy/repository/abc.py b/src/starlite_saqlalchemy/repository/abc.py index 20b96cc..d51fa71 100644 --- a/src/starlite_saqlalchemy/repository/abc.py +++ b/src/starlite_saqlalchemy/repository/abc.py @@ -7,11 +7,14 @@ from starlite_saqlalchemy.exceptions import NotFoundError if TYPE_CHECKING: + from collections.abc import Sequence + from .types import FilterTypes __all__ = ["AbstractRepository"] T = TypeVar("T") +CollectionT = TypeVar("CollectionT") RepoT = TypeVar("RepoT", bound="AbstractRepository") @@ -39,32 +42,29 @@ async def add(self, data: T) -> T: """ @abstractmethod - async def delete(self, id_: Any) -> T: - """Delete instance identified by `id_`. + async def count(self, *filters: FilterTypes, **kwargs: Any) -> int: + """Get the count of records returned by a query. Args: - id_: Identifier of instance to be deleted. + *filters: Types for specific filtering operations. + **kwargs: Instance attribute value filters. Returns: - The deleted instance. - - Raises: - RepositoryNotFoundException: If no instance found identified by `id_`. + The count of instances """ @abstractmethod - async def count(self, *filters: FilterTypes, **kwargs: Any) -> int: - """Get the count of records returned by a query. Optionally filtered. - - Args: - *filters: Types for specific filtering operations. - **kwargs: Instance attribute value filters. + async def delete(self, id_: Any) -> T: + """Delete instance identified by `id_`. Args: - select_: Optional statement to use for counting. + id_: Identifier of instance to be deleted. Returns: - The count of instances + The deleted instance. + + Raises: + RepositoryNotFoundException: If no instance found identified by `id_`. """ @abstractmethod @@ -82,7 +82,7 @@ async def get(self, id_: Any) -> T: """ @abstractmethod - async def list(self, *filters: FilterTypes, **kwargs: Any) -> list[T]: + async def list(self, *filters: FilterTypes, **kwargs: Any) -> Sequence[T]: """Get a list of instances, optionally filtered. Args: @@ -93,6 +93,18 @@ async def list(self, *filters: FilterTypes, **kwargs: Any) -> list[T]: The list of instances, after filtering applied. """ + @abstractmethod + async def list_and_count(self, *filters: FilterTypes, **kwargs: Any) -> tuple[Sequence[T], int]: + """ + + Args: + *filters: Types for specific filtering operations. + **kwargs: Instance attribute value filters. + + Returns: + Count of records returned by query, ignoring pagination. + """ + @abstractmethod async def update(self, data: T) -> T: """Update instance with the attribute values present on `data`. @@ -128,12 +140,13 @@ async def upsert(self, data: T) -> T: """ @abstractmethod - def filter_collection_by_kwargs(self, **kwargs: Any) -> None: + def filter_collection_by_kwargs(self, collection: CollectionT, /, **kwargs: Any) -> CollectionT: """Filter the collection by kwargs. Has `AND` semantics where multiple kwargs name/value pairs are provided. Args: + collection: the collection to be filtered **kwargs: key/value pairs such that objects remaining in the collection after filtering have the property that their attribute named `key` has value equal to `value`. diff --git a/src/starlite_saqlalchemy/repository/sqlalchemy.py b/src/starlite_saqlalchemy/repository/sqlalchemy.py index a11c556..f3b9ddf 100644 --- a/src/starlite_saqlalchemy/repository/sqlalchemy.py +++ b/src/starlite_saqlalchemy/repository/sqlalchemy.py @@ -2,9 +2,9 @@ from __future__ import annotations from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar +from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast -from sqlalchemy import select, text +from sqlalchemy import over, select, text from sqlalchemy.exc import IntegrityError, SQLAlchemyError from sqlalchemy.sql import func as sql_func @@ -34,7 +34,9 @@ T = TypeVar("T") ModelT = TypeVar("ModelT", bound="orm.Base | orm.AuditBase") +RowT = TypeVar("RowT", bound=tuple[Any, ...]) SQLARepoT = TypeVar("SQLARepoT", bound="SQLAlchemyRepository") +SelectT = TypeVar("SelectT", bound="Select[Any]") @contextmanager @@ -61,9 +63,7 @@ def wrap_sqlalchemy_exception() -> Any: class SQLAlchemyRepository(AbstractRepository[ModelT], Generic[ModelT]): """SQLAlchemy based implementation of the repository interface.""" - def __init__( - self, *, session: AsyncSession, select_: Select[tuple[ModelT]] | None = None, **kwargs: Any - ) -> None: + def __init__(self, *, session: AsyncSession, **kwargs: Any) -> None: """ Args: session: Session managing the unit-of-work for the operation. @@ -71,7 +71,6 @@ def __init__( """ super().__init__(**kwargs) self.session = session - self._select = select(self.model_type) if select_ is None else select_ async def add(self, data: ModelT) -> ModelT: """Add `data` to the collection. @@ -89,6 +88,33 @@ async def add(self, data: ModelT) -> ModelT: self.session.expunge(instance) return instance + async def count(self, *filters: FilterTypes, **kwargs: Any) -> int: + """ + + Args: + *filters: Types for specific filtering operations. + **kwargs: Instance attribute value filters. + + Returns: + Count of records returned by query, ignoring pagination. + """ + select_ = select(sql_func.count(self.model_type.id)) # type:ignore[attr-defined] + for filter_ in filters: + match filter_: + case LimitOffset(_, _): + pass + # we do not apply this filter to the count since we need the total rows + case BeforeAfter(field_name, before, after): + select_ = self._filter_on_datetime_field( + field_name, before, after, select_=select_ + ) + case CollectionFilter(field_name, values): + select_ = self._filter_in_collection(field_name, values, select_=select_) + case _: + raise StarliteSaqlalchemyError(f"Unexpected filter: {filter}") + results = await self._execute(select_) + return results.scalar_one() # type: ignore[no-any-return] + async def delete(self, id_: Any) -> ModelT: """Delete instance identified by `id_`. @@ -120,14 +146,15 @@ async def get(self, id_: Any) -> ModelT: Raises: RepositoryNotFoundException: If no instance found identified by `id_`. """ + select_ = self._create_select_for_model() with wrap_sqlalchemy_exception(): - self._filter_select_by_kwargs(**{self.id_attribute: id_}) - instance = (await self._execute()).scalar_one_or_none() + select_ = self._filter_select_by_kwargs(select_, **{self.id_attribute: id_}) + instance = (await self._execute(select_)).scalar_one_or_none() instance = self.check_not_found(instance) self.session.expunge(instance) return instance - async def list(self, *filters: FilterTypes, **kwargs: Any) -> tuple[list[ModelT], int]: + async def list(self, *filters: FilterTypes, **kwargs: Any) -> abc.Sequence[ModelT]: """Get a list of instances, optionally filtered. Args: @@ -137,29 +164,44 @@ async def list(self, *filters: FilterTypes, **kwargs: Any) -> tuple[list[ModelT] Returns: The list of instances, after filtering applied. """ - count_query = self._select - for filter_ in filters: - match filter_: - case LimitOffset(limit, offset): - self._apply_limit_offset_pagination(limit, offset) - # we do not apply this filter to the count since we need the total rows - case BeforeAfter(field_name, before, after): - self._filter_on_datetime_field(field_name, before, after) - count_query = self._select - case CollectionFilter(field_name, values): - self._filter_in_collection(field_name, values) - count_query = self._select - case _: - raise StarliteSaqlalchemyError(f"Unexpected filter: {filter}") - self._filter_select_by_kwargs(**kwargs) - count_query = self._select + select_ = self._create_select_for_model() + select_ = self._filter_for_list(*filters, select_=select_) + select_ = self._filter_select_by_kwargs(select_, **kwargs) with wrap_sqlalchemy_exception(): - result = await self._execute() - count = await self.count(count_query) + result = await self._execute(select_) instances = list(result.scalars()) for instance in instances: self.session.expunge(instance) + return instances + + async def list_and_count( + self, *filters: FilterTypes, **kwargs: Any + ) -> tuple[abc.Sequence[ModelT], int]: + """ + + Args: + *filters: Types for specific filtering operations. + **kwargs: Instance attribute value filters. + + Returns: + Count of records returned by query, ignoring pagination. + """ + select_ = select( + self.model_type, + over(sql_func.count(self.model_type.id)), # type:ignore[attr-defined] + ) + select_ = self._filter_for_list(*filters, select_=select_) + select_ = self._filter_select_by_kwargs(select_, **kwargs) + with wrap_sqlalchemy_exception(): + result = await self._execute(select_) + count: int = 0 + instances: list[ModelT] = [] + for i, (instance, count_value) in enumerate(result): + self.session.expunge(instance) + instances.append(instance) + if i == 0: + count = count_value return instances, count async def update(self, data: ModelT) -> ModelT: @@ -210,33 +252,21 @@ async def upsert(self, data: ModelT) -> ModelT: self.session.expunge(instance) return instance - async def count(self, select_: Select[tuple[ModelT]] | None = None) -> int: - """Count records returned by query. - - Args: - select_ (Select | None): Optional SQL statement to generate a count statement. Defaults to [self._select] - - Returns: - int: _description_ - """ - if select_ is None: - select_ = self._select - count_statement = select_.with_only_columns( - sql_func.count(), - maintain_column_froms=True, - ).order_by(None) - results = await self.session.execute(count_statement) - return results.scalar_one() # type: ignore - - def filter_collection_by_kwargs(self, **kwargs: Any) -> None: + def filter_collection_by_kwargs( # type:ignore[override] + self, + collection: SelectT, + /, + **kwargs: Any, + ) -> SelectT: """Filter the collection by kwargs. Args: + collection: select to filter **kwargs: key/value pairs such that objects remaining in the collection after filtering have the property that their attribute named `key` has value equal to `value`. """ with wrap_sqlalchemy_exception(): - self._select.filter_by(**kwargs) + return collection.filter_by(**kwargs) @classmethod async def check_health(cls, session: AsyncSession) -> bool: @@ -254,8 +284,10 @@ async def check_health(cls, session: AsyncSession) -> bool: # the following is all sqlalchemy implementation detail, and shouldn't be directly accessed - def _apply_limit_offset_pagination(self, limit: int, offset: int) -> None: - self._select = self._select.limit(limit).offset(offset) + def _apply_limit_offset_pagination( + self, limit: int, offset: int, *, select_: SelectT + ) -> SelectT: + return select_.limit(limit).offset(offset) async def _attach_to_session( self, model: ModelT, strategy: Literal["add", "merge"] = "add" @@ -281,23 +313,56 @@ async def _attach_to_session( case _: raise ValueError("Unexpected value for `strategy`, must be `'add'` or `'merge'`") - async def _execute(self) -> Result[tuple[ModelT, ...]]: - return await self.session.execute(self._select) + def _create_select_for_model(self) -> Select[tuple[ModelT]]: + return select(self.model_type) + + async def _execute(self, select_: Select[RowT]) -> Result[RowT]: + return cast("Result[RowT]", await self.session.execute(select_)) + + def _filter_for_list(self, *filters: FilterTypes, select_: SelectT) -> SelectT: + """ + Args: + *filters: filter types to apply to the query + + Keyword Args: + select_: select to apply filters against + + Returns: + The select with filters applied. + """ + for filter_ in filters: + match filter_: + case LimitOffset(limit, offset): + select_ = self._apply_limit_offset_pagination(limit, offset, select_=select_) + case BeforeAfter(field_name, before, after): + select_ = self._filter_on_datetime_field( + field_name, before, after, select_=select_ + ) + case CollectionFilter(field_name, values): + select_ = self._filter_in_collection(field_name, values, select_=select_) + case _: + raise StarliteSaqlalchemyError(f"Unexpected filter: {filter}") + return select_ - def _filter_in_collection(self, field_name: str, values: abc.Collection[Any]) -> None: + def _filter_in_collection( + self, field_name: str, values: abc.Collection[Any], *, select_: SelectT + ) -> SelectT: if not values: - return - self._select = self._select.where(getattr(self.model_type, field_name).in_(values)) + return select_ + + return select_.where(getattr(self.model_type, field_name).in_(values)) def _filter_on_datetime_field( - self, field_name: str, before: datetime | None, after: datetime | None - ) -> None: + self, field_name: str, before: datetime | None, after: datetime | None, *, select_: SelectT + ) -> SelectT: field = getattr(self.model_type, field_name) if before is not None: - self._select = self._select.where(field < before) + select_ = select_.where(field < before) if after is not None: - self._select = self._select.where(field > before) + return select_.where(field > before) + return select_ - def _filter_select_by_kwargs(self, **kwargs: Any) -> None: + def _filter_select_by_kwargs(self, select_: SelectT, **kwargs: Any) -> SelectT: for key, val in kwargs.items(): - self._select = self._select.where(getattr(self.model_type, key) == val) + select_ = select_.where(getattr(self.model_type, key) == val) + return select_ diff --git a/src/starlite_saqlalchemy/service/generic.py b/src/starlite_saqlalchemy/service/generic.py index d733532..255f8eb 100644 --- a/src/starlite_saqlalchemy/service/generic.py +++ b/src/starlite_saqlalchemy/service/generic.py @@ -4,6 +4,7 @@ """ from __future__ import annotations +import asyncio import contextlib from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar @@ -11,7 +12,7 @@ from starlite_saqlalchemy.exceptions import NotFoundError if TYPE_CHECKING: - from collections.abc import AsyncIterator + from collections.abc import AsyncIterator, Sequence T = TypeVar("T") @@ -42,6 +43,16 @@ def __init_subclass__(cls, *_: Any, **__: Any) -> None: # pylint:disable=unused-argument + async def count(self, **kwargs: Any) -> int: + """ + Args: + **kwargs: key value pairs of filter types. + + Returns: + A count of the collection, filtered, but ignoring pagination. + """ + return 0 + async def create(self, data: T) -> T: """Create an instance of `T`. @@ -53,7 +64,7 @@ async def create(self, data: T) -> T: """ return data - async def list(self, **kwargs: Any) -> tuple[list[T], int]: + async def list(self, **kwargs: Any) -> Sequence[T]: """Return view of the collection of `T`. Args: @@ -62,7 +73,18 @@ async def list(self, **kwargs: Any) -> tuple[list[T], int]: Returns: The list of instances retrieved from the repository. """ - return [], 0 + return [] + + async def list_and_count(self, **kwargs: Any) -> tuple[Sequence[T], int]: + """ + Args: + **kwargs: Keyword arguments for filtering. + + Returns: + List of instances and count of total collection, ignoring pagination. + """ + collection, count = await asyncio.gather(self.list(**kwargs), self.count(**kwargs)) + return collection, count async def update(self, id_: Any, data: T) -> T: """Update existing instance of `T` with `data`. diff --git a/src/starlite_saqlalchemy/service/sqlalchemy.py b/src/starlite_saqlalchemy/service/sqlalchemy.py index ae18fef..de1a626 100644 --- a/src/starlite_saqlalchemy/service/sqlalchemy.py +++ b/src/starlite_saqlalchemy/service/sqlalchemy.py @@ -15,7 +15,7 @@ from .generic import Service if TYPE_CHECKING: - from collections.abc import AsyncIterator + from collections.abc import AsyncIterator, Sequence from starlite_saqlalchemy.repository.abc import AbstractRepository from starlite_saqlalchemy.repository.types import FilterTypes @@ -39,6 +39,16 @@ def __init__(self, **repo_kwargs: Any) -> None: """ self.repository = self.repository_type(**repo_kwargs) + async def count(self, *filters: FilterTypes, **kwargs: Any) -> int: + """ + Args: + **kwargs: key value pairs of filter types. + + Returns: + A count of the collection, filtered, but ignoring pagination. + """ + return await self.repository.count(*filters, **kwargs) + async def create(self, data: ModelT) -> ModelT: """Wrap repository instance creation. @@ -50,7 +60,7 @@ async def create(self, data: ModelT) -> ModelT: """ return await self.repository.add(data) - async def list(self, *filters: "FilterTypes", **kwargs: Any) -> tuple[list[ModelT], int]: + async def list(self, *filters: FilterTypes, **kwargs: Any) -> Sequence[ModelT]: """Wrap repository scalars operation. Args: @@ -62,6 +72,18 @@ async def list(self, *filters: "FilterTypes", **kwargs: Any) -> tuple[list[Model """ return await self.repository.list(*filters, **kwargs) + async def list_and_count( + self, *filters: FilterTypes, **kwargs: Any + ) -> tuple[Sequence[ModelT], int]: + """ + Args: + **kwargs: Keyword arguments for filtering. + + Returns: + List of instances and count of total collection, ignoring pagination. + """ + return await self.repository.list_and_count(*filters, **kwargs) + async def update(self, id_: Any, data: ModelT) -> ModelT: """Wrap repository update operation. diff --git a/src/starlite_saqlalchemy/testing/generic_mock_repository.py b/src/starlite_saqlalchemy/testing/generic_mock_repository.py index 18bf09e..66c0361 100644 --- a/src/starlite_saqlalchemy/testing/generic_mock_repository.py +++ b/src/starlite_saqlalchemy/testing/generic_mock_repository.py @@ -13,7 +13,7 @@ from starlite_saqlalchemy.repository.abc import AbstractRepository if TYPE_CHECKING: - from collections.abc import Callable, Hashable, Iterable, MutableMapping + from collections.abc import Callable, Hashable, Iterable, MutableMapping, Sequence from typing import Any from starlite_saqlalchemy.repository.types import FilterTypes @@ -71,6 +71,18 @@ async def add(self, data: ModelT, allow_id: bool = False) -> ModelT: self.collection[data.id] = data return data + async def count(self, *filters: FilterTypes, **kwargs: Any) -> int: + """ + + Args: + *filters: Types for specific filtering operations. + **kwargs: Instance attribute value filters. + + Returns: + Count of instances in collection, ignoring pagination. + """ + return len(await self.list(*filters, **kwargs)) + async def delete(self, id_: Any) -> ModelT: """Delete instance identified by `id_`. @@ -102,7 +114,7 @@ async def get(self, id_: Any) -> ModelT: """ return self._find_or_raise_not_found(id_) - async def list(self, *filters: "FilterTypes", **kwargs: Any) -> tuple[list[ModelT], int]: + async def list(self, *filters: FilterTypes, **kwargs: Any) -> Sequence[ModelT]: """Get a list of instances, optionally filtered. Args: @@ -112,19 +124,21 @@ async def list(self, *filters: "FilterTypes", **kwargs: Any) -> tuple[list[Model Returns: The list of instances, after filtering applied. """ - return list(self.collection.values()), len(list(self.collection.values())) + return tuple(self.collection.values()) - async def count(self, select_: Any | None = None) -> int: - """Get a list of instances, optionally filtered. + async def list_and_count( + self, *filters: FilterTypes, **kwargs: Any + ) -> tuple[Sequence[ModelT], int]: + """ Args: *filters: Types for specific filtering operations. **kwargs: Instance attribute value filters. Returns: - The list of instances, after filtering applied. + Count of records returned by query, ignoring pagination. """ - return len(list(self.collection.values())) + return await self.list(*filters, **kwargs), await self.count(*filters, **kwargs) async def update(self, data: ModelT) -> ModelT: """Update instance with the attribute values present on `data`. @@ -172,7 +186,12 @@ async def upsert(self, data: ModelT) -> ModelT: return await self.update(data) return await self.add(data, allow_id=True) - def filter_collection_by_kwargs(self, **kwargs: Any) -> None: + def filter_collection_by_kwargs( # type:ignore[override] + self, + collection: MutableMapping[Hashable, ModelT], + /, + **kwargs: Any, + ) -> None: """Filter the collection by kwargs. Args: diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 3f2b3d9..2500f46 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -166,7 +166,7 @@ async def fx_engine(docker_ip: str) -> AsyncEngine: @pytest.fixture(autouse=True) -async def _seed_db(engine: AsyncEngine, raw_authors: list[dict[str, Any]]) -> AsyncIterator[None]: +async def _seed_db(engine: AsyncEngine, raw_authors: list[dict[str, Any]]) -> None: """Populate test database with. Args: @@ -175,6 +175,7 @@ async def _seed_db(engine: AsyncEngine, raw_authors: list[dict[str, Any]]) -> As metadata = db.orm.Base.registry.metadata author_table = metadata.tables["author"] async with engine.begin() as conn: + await conn.run_sync(metadata.drop_all) await conn.run_sync(metadata.create_all) # convert date/time strings to dt objects. @@ -185,9 +186,6 @@ async def _seed_db(engine: AsyncEngine, raw_authors: list[dict[str, Any]]) -> As async with engine.begin() as conn: await conn.execute(author_table.insert(), raw_authors) - yield - async with engine.begin() as conn: - await conn.run_sync(metadata.drop_all) @pytest.fixture(autouse=True) diff --git a/tests/integration/repository/test_sqlalchemy_repository.py b/tests/integration/repository/test_sqlalchemy_repository.py index ff04d34..8b06c92 100644 --- a/tests/integration/repository/test_sqlalchemy_repository.py +++ b/tests/integration/repository/test_sqlalchemy_repository.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from typing import Any + import pytest from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker @@ -16,5 +20,20 @@ def fx_repo(session: AsyncSession) -> authors.Repository: def test_filter_by_kwargs_with_incorrect_attribute_name(repo: authors.Repository) -> None: + select_ = repo._create_select_for_model() with pytest.raises(StarliteSaqlalchemyError): - repo.filter_collection_by_kwargs(whoops="silly me") + repo.filter_collection_by_kwargs(select_, whoops="silly me") + + +async def test_repo_count_method(repo: authors.Repository) -> None: + assert await repo.count() == 2 + + +async def test_repo_list_and_count_method( + raw_authors: list[dict[str, Any]], repo: authors.Repository +) -> None: + exp_count = len(raw_authors) + collection, count = await repo.list_and_count() + assert exp_count == count + assert isinstance(collection, list) + assert len(collection) == exp_count diff --git a/tests/unit/require_sqlalchemy/repository/test_generic_mock_repository.py b/tests/unit/require_sqlalchemy/repository/test_generic_mock_repository.py index c35c9dc..32d6838 100644 --- a/tests/unit/require_sqlalchemy/repository/test_generic_mock_repository.py +++ b/tests/unit/require_sqlalchemy/repository/test_generic_mock_repository.py @@ -47,7 +47,7 @@ def test_generic_mock_repository_filter_collection_by_kwargs( author_repository: GenericMockRepository[Author], ) -> None: """Test filtering the repository collection by kwargs.""" - author_repository.filter_collection_by_kwargs(name="Leo Tolstoy") + author_repository.filter_collection_by_kwargs(author_repository.collection, name="Leo Tolstoy") assert len(author_repository.collection) == 1 assert list(author_repository.collection.values())[0].name == "Leo Tolstoy" @@ -57,7 +57,9 @@ def test_generic_mock_repository_filter_collection_by_kwargs_and_semantics( ) -> None: """Test that filtering by kwargs has `AND` semantics when multiple kwargs, not `OR`.""" - author_repository.filter_collection_by_kwargs(name="Agatha Christie", dob="1828-09-09") + author_repository.filter_collection_by_kwargs( + author_repository.collection, name="Agatha Christie", dob="1828-09-09" + ) assert len(author_repository.collection) == 0 @@ -67,7 +69,7 @@ def test_generic_mock_repository_raises_repository_exception_if_named_attribute_ """Test that a repo exception is raised if a named attribute doesn't exist.""" with pytest.raises(StarliteSaqlalchemyError): - author_repository.filter_collection_by_kwargs(cricket="ball") + author_repository.filter_collection_by_kwargs(author_repository.collection, cricket="ball") async def test_sets_created_updated_on_add() -> None: diff --git a/tests/unit/require_sqlalchemy/repository/test_sqlalchemy.py b/tests/unit/require_sqlalchemy/repository/test_sqlalchemy.py index 1d34de8..990d6ad 100644 --- a/tests/unit/require_sqlalchemy/repository/test_sqlalchemy.py +++ b/tests/unit/require_sqlalchemy/repository/test_sqlalchemy.py @@ -34,7 +34,10 @@ class Repo(SQLAlchemyRepository[MagicMock]): model_type = MagicMock() # pyright:ignore[reportGeneralTypeIssues] - return Repo(session=AsyncMock(spec=AsyncSession), select_=MagicMock()) + def _create_select_for_model(self) -> MagicMock: + return MagicMock() + + return Repo(session=AsyncMock(spec=AsyncSession)) def test_wrap_sqlalchemy_integrity_error() -> None: @@ -97,15 +100,10 @@ async def test_sqlalchemy_repo_list( mock_instances = [MagicMock(), MagicMock()] result_mock = MagicMock() result_mock.scalars = MagicMock(return_value=mock_instances) - count_mock = MagicMock() - count_mock.return_value = 2 execute_mock = AsyncMock(return_value=result_mock) - execute_count_mock = AsyncMock(return_value=count_mock) - monkeypatch.setattr(mock_repo, "count", execute_count_mock) monkeypatch.setattr(mock_repo, "_execute", execute_mock) - instances, count = await mock_repo.list() + instances = await mock_repo.list() assert instances == mock_instances - assert count == count_mock mock_repo.session.expunge.assert_has_calls(*mock_instances) mock_repo.session.commit.assert_not_called() @@ -117,11 +115,13 @@ async def test_sqlalchemy_repo_list_with_pagination( result_mock = MagicMock() execute_mock = AsyncMock(return_value=result_mock) monkeypatch.setattr(mock_repo, "_execute", execute_mock) - mock_repo._select.limit.return_value = mock_repo._select - mock_repo._select.offset.return_value = mock_repo._select + select_ = MagicMock() + select_.limit.return_value = select_ + select_.offset.return_value = select_ + monkeypatch.setattr(mock_repo, "_create_select_for_model", lambda: select_) await mock_repo.list(LimitOffset(2, 3)) - mock_repo._select.limit.assert_called_once_with(2) - mock_repo._select.limit().offset.assert_called_once_with(3) # type:ignore[call-arg] + select_.limit.assert_called_once_with(2) + select_.limit().offset.assert_called_once_with(3) async def test_sqlalchemy_repo_count( @@ -150,10 +150,12 @@ async def test_sqlalchemy_repo_list_with_before_after_filter( result_mock = MagicMock() execute_mock = AsyncMock(return_value=result_mock) monkeypatch.setattr(mock_repo, "_execute", execute_mock) - mock_repo._select.where.return_value = mock_repo._select + select_ = MagicMock() + select_.where.return_value = select_ + monkeypatch.setattr(mock_repo, "_create_select_for_model", lambda: select_) await mock_repo.list(BeforeAfter(field_name, datetime.max, datetime.min)) - assert mock_repo._select.where.call_count == 2 - assert mock_repo._select.where.has_calls([call("gt"), call("lt")]) + assert select_.where.call_count == 2 + assert select_.where.has_calls([call("gt"), call("lt")]) async def test_sqlalchemy_repo_list_with_collection_filter( @@ -164,10 +166,12 @@ async def test_sqlalchemy_repo_list_with_collection_filter( result_mock = MagicMock() execute_mock = AsyncMock(return_value=result_mock) monkeypatch.setattr(mock_repo, "_execute", execute_mock) - mock_repo._select.where.return_value = mock_repo._select + select_ = MagicMock() + select_.where.return_value = select_ + monkeypatch.setattr(mock_repo, "_create_select_for_model", lambda: select_) values = [1, 2, 3] await mock_repo.list(CollectionFilter(field_name, values)) - mock_repo._select.where.assert_called_once() + select_.where.assert_called_once() getattr(mock_repo.model_type, field_name).in_.assert_called_once_with(values) @@ -220,14 +224,16 @@ async def test_attach_to_session_unexpected_strategy_raises_valueerror( async def test_execute(mock_repo: SQLAlchemyRepository) -> None: """Simple test of the abstraction over `AsyncSession.execute()`""" - await mock_repo._execute() - mock_repo.session.execute.assert_called_once_with(mock_repo._select) + select_ = mock_repo._create_select_for_model() + await mock_repo._execute(select_) + mock_repo.session.execute.assert_called_once_with(select_) def test_filter_in_collection_noop_if_collection_empty(mock_repo: SQLAlchemyRepository) -> None: """Ensures we don't filter on an empty collection.""" - mock_repo._filter_in_collection("id", []) - mock_repo._select.where.assert_not_called() + select_ = mock_repo._create_select_for_model() + mock_repo._filter_in_collection("id", [], select_=select_) + select_.where.assert_not_called() @pytest.mark.parametrize( @@ -245,13 +251,16 @@ def test__filter_on_datetime_field( field_mock = MagicMock() field_mock.__gt__ = field_mock.__lt__ = lambda self, other: True mock_repo.model_type.updated = field_mock - mock_repo._filter_on_datetime_field("updated", before, after) + mock_repo._filter_on_datetime_field( + "updated", before, after, select_=mock_repo._create_select_for_model() + ) def test_filter_collection_by_kwargs(mock_repo: SQLAlchemyRepository) -> None: """Test `filter_by()` called with kwargs.""" - mock_repo.filter_collection_by_kwargs(a=1, b=2) - mock_repo._select.filter_by.assert_called_once_with(a=1, b=2) + select_ = mock_repo._create_select_for_model() + mock_repo.filter_collection_by_kwargs(select_, a=1, b=2) + select_.filter_by.assert_called_once_with(a=1, b=2) def test_filter_collection_by_kwargs_raises_repository_exception_for_attribute_error( @@ -259,8 +268,9 @@ def test_filter_collection_by_kwargs_raises_repository_exception_for_attribute_e ) -> None: """Test that we raise a repository exception if an attribute name is incorrect.""" - mock_repo._select.filter_by = MagicMock( # type:ignore[assignment] + select_ = mock_repo._create_select_for_model() + select_.filter_by = MagicMock( # type:ignore[assignment] side_effect=InvalidRequestError, ) with pytest.raises(StarliteSaqlalchemyError): - mock_repo.filter_collection_by_kwargs(a=1) + mock_repo.filter_collection_by_kwargs(select_, a=1) diff --git a/tests/unit/require_sqlalchemy/test_service.py b/tests/unit/require_sqlalchemy/test_service.py index 14e33d6..647631e 100644 --- a/tests/unit/require_sqlalchemy/test_service.py +++ b/tests/unit/require_sqlalchemy/test_service.py @@ -36,15 +36,15 @@ async def test_service_create() -> None: async def test_service_list() -> None: """Test repository list action.""" - items, count = await domain.authors.Service().list() - assert isinstance(items, list) - assert count == 2 + items = await domain.authors.Service().list() + assert isinstance(items, tuple) + assert len(items) == 2 async def test_service_update() -> None: """Test repository update action.""" service_obj = domain.authors.Service() - authors, _ = await service_obj.list() + authors = await service_obj.list() author = authors[0] assert author.name == "Agatha Christie" author.name = "different" @@ -55,7 +55,7 @@ async def test_service_update() -> None: async def test_service_upsert_update() -> None: """Test repository upsert action for update.""" service_obj = domain.authors.Service() - authors, _ = await service_obj.list() + authors = await service_obj.list() author = authors[0] assert author.name == "Agatha Christie" author.name = "different" @@ -75,7 +75,7 @@ async def test_service_upsert_create() -> None: async def test_service_get() -> None: """Test repository get action.""" service_obj = domain.authors.Service() - authors, _ = await service_obj.list() + authors = await service_obj.list() author = authors[0] retrieved = await service_obj.get(author.id) assert author is retrieved @@ -84,7 +84,7 @@ async def test_service_get() -> None: async def test_service_delete() -> None: """Test repository delete action.""" service_obj = domain.authors.Service() - authors, _ = await service_obj.list() + authors = await service_obj.list() author = authors[0] deleted = await service_obj.delete(author.id) assert author is deleted @@ -100,8 +100,10 @@ async def test_service_method_default_behavior() -> None: """Test default behavior of base service methods.""" service_obj = service.Service[object]() data = object() + assert await service_obj.count() == 0 assert await service_obj.create(data) is data - assert await service_obj.list() == ([], 0) + assert await service_obj.list() == [] + assert await service_obj.list_and_count() == ([], 0) assert await service_obj.update("abc", data) is data assert await service_obj.upsert("abc", data) is data with pytest.raises(NotFoundError): From 2793a8e2fc2ea03ac0db7e546977497b2fb7a654 Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Wed, 25 Jan 2023 13:18:50 +1000 Subject: [PATCH 09/11] Update src/starlite_saqlalchemy/repository/sqlalchemy.py --- src/starlite_saqlalchemy/repository/sqlalchemy.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/starlite_saqlalchemy/repository/sqlalchemy.py b/src/starlite_saqlalchemy/repository/sqlalchemy.py index f3b9ddf..b333302 100644 --- a/src/starlite_saqlalchemy/repository/sqlalchemy.py +++ b/src/starlite_saqlalchemy/repository/sqlalchemy.py @@ -67,7 +67,6 @@ def __init__(self, *, session: AsyncSession, **kwargs: Any) -> None: """ Args: session: Session managing the unit-of-work for the operation. - select_: To facilitate customization of the underlying select query. """ super().__init__(**kwargs) self.session = session From 80a3f3a407cf6c6b7802d32caccdb84b2b587585 Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Fri, 27 Jan 2023 08:18:38 +1000 Subject: [PATCH 10/11] Revert "build(deps): pin to starlite <= 1.50.0" This reverts commit 65e6bfecc654718e42b627d98f8657fc1cae9810. --- poetry.lock | 10 +++++----- pyproject.toml | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/poetry.lock b/poetry.lock index ed12a85..307736b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1042,14 +1042,14 @@ sqlcipher = ["sqlcipher3-binary"] [[package]] name = "starlite" -version = "1.50.0" +version = "1.50.1" description = "Performant, light and flexible ASGI API Framework" category = "main" optional = false python-versions = ">=3.8,<4.0" files = [ - {file = "starlite-1.50.0-py3-none-any.whl", hash = "sha256:9d3379dec91fdfa0fb4a2cb77619e41455b89f280258f253de361f2f94464313"}, - {file = "starlite-1.50.0.tar.gz", hash = "sha256:6ed8572252cbb60cf70909cfb16721c0cc59d95c9fb5eaf8a4b0c7e8cfd68313"}, + {file = "starlite-1.50.1-py3-none-any.whl", hash = "sha256:31d45efc805b895b9733b58825942129e55ebbbf8ddaeaf931b1237af5daae37"}, + {file = "starlite-1.50.1.tar.gz", hash = "sha256:7cc30ace31c47c406666baff839e4dfa24e205855ef0fada011b9e1bbadda852"}, ] [package.dependencies] @@ -1070,7 +1070,7 @@ brotli = ["brotli"] cli = ["click", "jsbeautifier", "rich (>=13.0.0)"] cryptography = ["cryptography"] full = ["aiomcache", "brotli", "click", "cryptography", "jinja2 (>=3.1.2)", "opentelemetry-instrumentation-asgi", "picologging", "python-jose", "redis[hiredis]", "rich (>=13.0.0)", "structlog"] -jina = ["jinja2 (>=3.1.2)"] +jinja = ["jinja2 (>=3.1.2)"] jwt = ["cryptography", "python-jose"] memcached = ["aiomcache"] opentelemetry = ["opentelemetry-instrumentation-asgi"] @@ -1215,4 +1215,4 @@ worker = ["saq", "hiredis"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "f614b5407e536db1af6f68bee978b170304c642eee949564a23604729b90c167" +content-hash = "7a5ed929059b3c670d3a4d40e6ac7cf871900e33ff937dd1ff34e615b36f56c1" diff --git a/pyproject.toml b/pyproject.toml index 42566e3..7a22583 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,7 @@ httpx = "*" msgspec = "*" pydantic = "*" python-dotenv = "*" -starlite = ">=1.40.1,<=1.50.0" +starlite = ">=1.40.1,<1.50.2" tenacity = "*" uvicorn = "*" uvloop = "*" From 89f13a1b6073089d7a927f53d88d04f3de6128cf Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Sun, 29 Jan 2023 22:50:54 +1000 Subject: [PATCH 11/11] feat(service): add `count()`/`list_and_count()` to `RepositoryService`. --- src/starlite_saqlalchemy/service/sqlalchemy.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/starlite_saqlalchemy/service/sqlalchemy.py b/src/starlite_saqlalchemy/service/sqlalchemy.py index de1a626..b541da2 100644 --- a/src/starlite_saqlalchemy/service/sqlalchemy.py +++ b/src/starlite_saqlalchemy/service/sqlalchemy.py @@ -42,7 +42,8 @@ def __init__(self, **repo_kwargs: Any) -> None: async def count(self, *filters: FilterTypes, **kwargs: Any) -> int: """ Args: - **kwargs: key value pairs of filter types. + *filters: Collection route filters. + **kwargs: Keyword arguments for attribute based filtering. Returns: A count of the collection, filtered, but ignoring pagination. @@ -77,7 +78,8 @@ async def list_and_count( ) -> tuple[Sequence[ModelT], int]: """ Args: - **kwargs: Keyword arguments for filtering. + *filters: Collection route filters. + **kwargs: Keyword arguments for attribute based filtering. Returns: List of instances and count of total collection, ignoring pagination.