From 61da7b4418616b39984eeaf02551b258b21044d4 Mon Sep 17 00:00:00 2001 From: collerek Date: Tue, 1 Dec 2020 08:27:08 +0100 Subject: [PATCH 1/3] expose querysetproxy on reverse of foreignkey (virtual fk), add additional methods from queryset to querysetproxy --- docs/queries.md | 2 +- ormar/exceptions.py | 4 + ormar/models/newbasemodel.py | 33 ++-- ormar/protocols/queryset_protocol.py | 34 +++- ormar/relations/querysetproxy.py | 128 +++++++++++--- ormar/relations/relation.py | 2 +- ormar/relations/relation_manager.py | 2 - ormar/relations/relation_proxy.py | 46 +++-- tests/test_models.py | 7 +- tests/test_queryproxy_on_m2m_models.py | 182 +++++++++++++++++++ tests/test_reverse_fk_queryset.py | 233 +++++++++++++++++++++++++ 11 files changed, 605 insertions(+), 68 deletions(-) create mode 100644 tests/test_queryproxy_on_m2m_models.py create mode 100644 tests/test_reverse_fk_queryset.py diff --git a/docs/queries.md b/docs/queries.md index c848bc1ad..17b23ceee 100644 --- a/docs/queries.md +++ b/docs/queries.md @@ -212,7 +212,7 @@ You can use special filter suffix to change the filter operands: * exact - like `album__name__exact='Malibu'` (exact match) * iexact - like `album__name__iexact='malibu'` (exact match case insensitive) -* contains - like `album__name__conatins='Mal'` (sql like) +* contains - like `album__name__contains='Mal'` (sql like) * icontains - like `album__name__icontains='mal'` (sql like case insensitive) * in - like `album__name__in=['Malibu', 'Barclay']` (sql in) * gt - like `position__gt=3` (sql >) diff --git a/ormar/exceptions.py b/ormar/exceptions.py index 0800a8105..0ec0c8ef2 100644 --- a/ormar/exceptions.py +++ b/ormar/exceptions.py @@ -6,6 +6,10 @@ class ModelDefinitionError(AsyncOrmException): pass +class ModelError(AsyncOrmException): + pass + + class ModelNotSet(AsyncOrmException): pass diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 8a9d518d0..05c31120c 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -7,6 +7,7 @@ Dict, List, Mapping, + MutableSequence, Optional, Sequence, Set, @@ -22,6 +23,7 @@ from pydantic import BaseModel import ormar # noqa I100 +from ormar.exceptions import ModelError from ormar.fields import BaseField from ormar.fields.foreign_key import ForeignKeyField from ormar.models.excludable import Excludable @@ -93,16 +95,21 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # type: ignore if "pk" in kwargs: kwargs[self.Meta.pkname] = kwargs.pop("pk") # build the models to set them and validate but don't register - new_kwargs = { - k: self._convert_json( - k, - self.Meta.model_fields[k].expand_relationship( - v, self, to_register=False - ), - "dumps", + try: + new_kwargs = { + k: self._convert_json( + k, + self.Meta.model_fields[k].expand_relationship( + v, self, to_register=False + ), + "dumps", + ) + for k, v in kwargs.items() + } + except KeyError as e: + raise ModelError( + f"Unknown field '{e.args[0]}' for model {self.get_name(lower=False)}" ) - for k, v in kwargs.items() - } values, fields_set, validation_error = pydantic.validate_model( self, new_kwargs # type: ignore @@ -249,7 +256,9 @@ def _get_related_not_excluded_fields( @staticmethod def _extract_nested_models_from_list( - models: List, include: Union[Set, Dict, None], exclude: Union[Set, Dict, None], + models: MutableSequence, + include: Union[Set, Dict, None], + exclude: Union[Set, Dict, None], ) -> List: result = [] for model in models: @@ -282,7 +291,7 @@ def _extract_nested_models( # noqa: CCR001 if self.Meta.model_fields[field].virtual and nested: continue nested_model = getattr(self, field) - if isinstance(nested_model, list): + if isinstance(nested_model, MutableSequence): dict_instance[field] = self._extract_nested_models_from_list( models=nested_model, include=self._skip_ellipsis(include, field), @@ -308,7 +317,7 @@ def dict( # type: ignore # noqa A003 exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, - nested: bool = False + nested: bool = False, ) -> "DictStrAny": # noqa: A003' dict_instance = super().dict( include=include, diff --git a/ormar/protocols/queryset_protocol.py b/ormar/protocols/queryset_protocol.py index 1320c2a12..7eb7092de 100644 --- a/ormar/protocols/queryset_protocol.py +++ b/ormar/protocols/queryset_protocol.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional, Sequence, TYPE_CHECKING, Union +from typing import Any, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Union try: from typing import Protocol @@ -6,14 +6,21 @@ from typing_extensions import Protocol # type: ignore if TYPE_CHECKING: # noqa: C901; #pragma nocover - from ormar import QuerySet, Model + from ormar import Model + from ormar.relations.querysetproxy import QuerysetProxy class QuerySetProtocol(Protocol): # pragma: nocover - def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003, A001 + def filter(self, **kwargs: Any) -> "QuerysetProxy": # noqa: A003, A001 ... - def select_related(self, related: Union[List, str]) -> "QuerySet": + def exclude(self, **kwargs: Any) -> "QuerysetProxy": # noqa: A003, A001 + ... + + def select_related(self, related: Union[List, str]) -> "QuerysetProxy": + ... + + def prefetch_related(self, related: Union[List, str]) -> "QuerysetProxy": ... async def exists(self) -> bool: @@ -25,10 +32,10 @@ async def count(self) -> int: async def clear(self) -> int: ... - def limit(self, limit_count: int) -> "QuerySet": + def limit(self, limit_count: int) -> "QuerysetProxy": ... - def offset(self, offset: int) -> "QuerySet": + def offset(self, offset: int) -> "QuerysetProxy": ... async def first(self, **kwargs: Any) -> "Model": @@ -44,3 +51,18 @@ async def all( # noqa: A003, A001 async def create(self, **kwargs: Any) -> "Model": ... + + async def get_or_create(self, **kwargs: Any) -> "Model": + ... + + async def update_or_create(self, **kwargs: Any) -> "Model": + ... + + def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy": + ... + + def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy": + ... + + def order_by(self, columns: Union[List, str]) -> "QuerysetProxy": + ... diff --git a/ormar/relations/querysetproxy.py b/ormar/relations/querysetproxy.py index 8e2d6b4fb..208b4d123 100644 --- a/ormar/relations/querysetproxy.py +++ b/ormar/relations/querysetproxy.py @@ -1,4 +1,15 @@ -from typing import Any, List, Optional, Sequence, TYPE_CHECKING, TypeVar, Union +from typing import ( + Any, + Dict, + List, + MutableSequence, + Optional, + Sequence, + Set, + TYPE_CHECKING, + TypeVar, + Union, +) import ormar @@ -6,6 +17,7 @@ from ormar.relations import Relation from ormar.models import Model from ormar.queryset import QuerySet + from ormar import RelationType T = TypeVar("T", bound=Model) @@ -14,9 +26,17 @@ class QuerysetProxy(ormar.QuerySetProtocol): if TYPE_CHECKING: # pragma no cover relation: "Relation" - def __init__(self, relation: "Relation") -> None: + def __init__( + self, relation: "Relation", type_: "RelationType", qryset: "QuerySet" = None + ) -> None: self.relation: Relation = relation - self._queryset: Optional["QuerySet"] = None + self._queryset: Optional["QuerySet"] = qryset + self.type_: "RelationType" = type_ + self._owner: "Model" = self.relation.manager.owner + self.related_field = self._owner.resolve_relation_field( + self.relation.to, self._owner + ) + self.owner_pk_value = self._owner.pk @property def queryset(self) -> "QuerySet": @@ -30,7 +50,7 @@ def queryset(self, value: "QuerySet") -> None: def _assign_child_to_parent(self, child: Optional["T"]) -> None: if child: - owner = self.relation._owner + owner = self._owner rel_name = owner.resolve_relation_name(owner, child) setattr(owner, rel_name, child) @@ -42,27 +62,26 @@ def _register_related(self, child: Union["T", Sequence[Optional["T"]]]) -> None: assert isinstance(child, ormar.Model) self._assign_child_to_parent(child) + def _clean_items_on_load(self) -> None: + if isinstance(self.relation.related_models, MutableSequence): + for item in self.relation.related_models[:]: + self.relation.remove(item) + async def create_through_instance(self, child: "T") -> None: queryset = ormar.QuerySet(model_cls=self.relation.through) - owner_column = self.relation._owner.get_name() + owner_column = self._owner.get_name() child_column = child.get_name() - kwargs = {owner_column: self.relation._owner, child_column: child} + kwargs = {owner_column: self._owner, child_column: child} await queryset.create(**kwargs) async def delete_through_instance(self, child: "T") -> None: queryset = ormar.QuerySet(model_cls=self.relation.through) - owner_column = self.relation._owner.get_name() + owner_column = self._owner.get_name() child_column = child.get_name() - kwargs = {owner_column: self.relation._owner, child_column: child} + kwargs = {owner_column: self._owner, child_column: child} link_instance = await queryset.filter(**kwargs).get() # type: ignore await link_instance.delete() - def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003 - return self.queryset.filter(**kwargs) - - def select_related(self, related: Union[List, str]) -> "QuerySet": - return self.queryset.select_related(related) - async def exists(self) -> bool: return await self.queryset.exists() @@ -70,17 +89,16 @@ async def count(self) -> int: return await self.queryset.count() async def clear(self) -> int: - queryset = ormar.QuerySet(model_cls=self.relation.through) - owner_column = self.relation._owner.get_name() - kwargs = {owner_column: self.relation._owner} + if self.type_ == ormar.RelationType.MULTIPLE: + queryset = ormar.QuerySet(model_cls=self.relation.through) + owner_column = self._owner.get_name() + else: + queryset = ormar.QuerySet(model_cls=self.relation.to) + owner_column = self.related_field.name + kwargs = {owner_column: self._owner} + self._clean_items_on_load() return await queryset.delete(**kwargs) # type: ignore - def limit(self, limit_count: int) -> "QuerySet": - return self.queryset.limit(limit_count) - - def offset(self, offset: int) -> "QuerySet": - return self.queryset.offset(offset) - async def first(self, **kwargs: Any) -> "Model": first = await self.queryset.first(**kwargs) self._register_related(first) @@ -88,16 +106,72 @@ async def first(self, **kwargs: Any) -> "Model": async def get(self, **kwargs: Any) -> "Model": get = await self.queryset.get(**kwargs) + self._clean_items_on_load() self._register_related(get) return get async def all(self, **kwargs: Any) -> Sequence[Optional["Model"]]: # noqa: A003 all_items = await self.queryset.all(**kwargs) + self._clean_items_on_load() self._register_related(all_items) return all_items async def create(self, **kwargs: Any) -> "Model": - create = await self.queryset.create(**kwargs) - self._register_related(create) - await self.create_through_instance(create) - return create + if self.type_ == ormar.RelationType.REVERSE: + kwargs[self.related_field.name] = self._owner + created = await self.queryset.create(**kwargs) + self._register_related(created) + if self.type_ == ormar.RelationType.MULTIPLE: + await self.create_through_instance(created) + return created + + async def get_or_create(self, **kwargs: Any) -> "Model": + try: + return await self.get(**kwargs) + except ormar.NoMatch: + return await self.create(**kwargs) + + async def update_or_create(self, **kwargs: Any) -> "Model": + pk_name = self.queryset.model_meta.pkname + if "pk" in kwargs: + kwargs[pk_name] = kwargs.pop("pk") + if pk_name not in kwargs or kwargs.get(pk_name) is None: + return await self.create(**kwargs) + model = await self.queryset.get(pk=kwargs[pk_name]) + return await model.update(**kwargs) + + def filter(self, **kwargs: Any) -> "QuerysetProxy": # noqa: A003, A001 + queryset = self.queryset.filter(**kwargs) + return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) + + def exclude(self, **kwargs: Any) -> "QuerysetProxy": # noqa: A003, A001 + queryset = self.queryset.exclude(**kwargs) + return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) + + def select_related(self, related: Union[List, str]) -> "QuerysetProxy": + queryset = self.queryset.select_related(related) + return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) + + def prefetch_related(self, related: Union[List, str]) -> "QuerysetProxy": + queryset = self.queryset.prefetch_related(related) + return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) + + def limit(self, limit_count: int) -> "QuerysetProxy": + queryset = self.queryset.limit(limit_count) + return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) + + def offset(self, offset: int) -> "QuerysetProxy": + queryset = self.queryset.offset(offset) + return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) + + def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy": + queryset = self.queryset.fields(columns) + return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) + + def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy": + queryset = self.queryset.exclude_fields(columns=columns) + return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) + + def order_by(self, columns: Union[List, str]) -> "QuerysetProxy": + queryset = self.queryset.order_by(columns) + return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) diff --git a/ormar/relations/relation.py b/ormar/relations/relation.py index e09f00ce6..b0183c3e7 100644 --- a/ormar/relations/relation.py +++ b/ormar/relations/relation.py @@ -34,7 +34,7 @@ def __init__( self.to: Type["T"] = to self.through: Optional[Type["T"]] = through self.related_models: Optional[Union[RelationProxy, "T"]] = ( - RelationProxy(relation=self) + RelationProxy(relation=self, type_=type_) if type_ in (RelationType.REVERSE, RelationType.MULTIPLE) else None ) diff --git a/ormar/relations/relation_manager.py b/ormar/relations/relation_manager.py index dfe1ee8ff..81183f4f9 100644 --- a/ormar/relations/relation_manager.py +++ b/ormar/relations/relation_manager.py @@ -65,8 +65,6 @@ def add(parent: "Model", child: "Model", child_name: str, virtual: bool) -> None parent_relation = parent._orm._get(child_name) if parent_relation: - # print('missing', child_name) - # parent_relation = register_missing_relation(parent, child, child_name) parent_relation.add(child) # type: ignore child_relation = child._orm._get(to_name) diff --git a/ormar/relations/relation_proxy.py b/ormar/relations/relation_proxy.py index f03252e39..25c7b1ca2 100644 --- a/ormar/relations/relation_proxy.py +++ b/ormar/relations/relation_proxy.py @@ -5,17 +5,18 @@ from ormar.relations.querysetproxy import QuerysetProxy if TYPE_CHECKING: # pragma no cover - from ormar import Model + from ormar import Model, RelationType from ormar.relations import Relation from ormar.queryset import QuerySet class RelationProxy(list): - def __init__(self, relation: "Relation") -> None: - super(RelationProxy, self).__init__() - self.relation: Relation = relation + def __init__(self, relation: "Relation", type_: "RelationType") -> None: + super().__init__() + self.relation: "Relation" = relation + self.type_: "RelationType" = type_ self._owner: "Model" = self.relation.manager.owner - self.queryset_proxy = QuerysetProxy(relation=self.relation) + self.queryset_proxy = QuerysetProxy(relation=self.relation, type_=type_) def __getattribute__(self, item: str) -> Any: if item in ["count", "clear"]: @@ -38,17 +39,19 @@ def _check_if_queryset_is_initialized(self) -> bool: ) def _set_queryset(self) -> "QuerySet": - owner_table = self.relation._owner.Meta.tablename - pkname = self.relation._owner.get_column_alias(self.relation._owner.Meta.pkname) - pk_value = self.relation._owner.pk + related_field = self._owner.resolve_relation_field( + self.relation.to, self._owner + ) + pkname = self._owner.get_column_alias(self._owner.Meta.pkname) + pk_value = self._owner.pk if not pk_value: raise RelationshipInstanceError( - "You cannot query many to many relationship on unsaved model." + "You cannot query relationships from unsaved model." ) - kwargs = {f"{owner_table}__{pkname}": pk_value} + kwargs = {f"{related_field.get_alias()}__{pkname}": pk_value} queryset = ( ormar.QuerySet(model_cls=self.relation.to) - .select_related(owner_table) + .select_related(related_field.name) .filter(**kwargs) ) return queryset @@ -67,14 +70,21 @@ async def remove(self, item: "Model") -> None: # type: ignore f"{self._owner.get_name()} does not have relation {rel_name}" ) relation.remove(self._owner) - if self.relation._type == ormar.RelationType.MULTIPLE: + self.relation.remove(item) + if self.type_ == ormar.RelationType.MULTIPLE: await self.queryset_proxy.delete_through_instance(item) - - def append(self, item: "Model") -> None: - super().append(item) + else: + setattr(item, rel_name, None) + await item.update() async def add(self, item: "Model") -> None: - if self.relation._type == ormar.RelationType.MULTIPLE: + if self.type_ == ormar.RelationType.MULTIPLE: await self.queryset_proxy.create_through_instance(item) - rel_name = item.resolve_relation_name(item, self._owner) - setattr(item, rel_name, self._owner) + rel_name = item.resolve_relation_name(item, self._owner) + setattr(item, rel_name, self._owner) + else: + related_field = self._owner.resolve_relation_field( + self.relation.to, self._owner + ) + setattr(item, related_field.name, self._owner) + await item.update() diff --git a/tests/test_models.py b/tests/test_models.py index 3c9ffc81d..24a81ade1 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -9,7 +9,7 @@ import sqlalchemy import ormar -from ormar.exceptions import QueryDefinitionError, NoMatch +from ormar.exceptions import QueryDefinitionError, NoMatch, ModelError from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL, force_rollback=True) @@ -117,6 +117,11 @@ def test_model_class(): assert isinstance(User.Meta.table, sqlalchemy.Table) +def test_wrong_field_name(): + with pytest.raises(ModelError): + User(non_existing_pk=1) + + def test_model_pk(): user = User(pk=1) assert user.pk == 1 diff --git a/tests/test_queryproxy_on_m2m_models.py b/tests/test_queryproxy_on_m2m_models.py new file mode 100644 index 000000000..d33aa5d84 --- /dev/null +++ b/tests/test_queryproxy_on_m2m_models.py @@ -0,0 +1,182 @@ +import asyncio +from typing import List, Optional, Union + +import databases +import pytest +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class Subject(ormar.Model): + class Meta: + tablename = "subjects" + database = database + metadata = metadata + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=80) + + +class Author(ormar.Model): + class Meta: + tablename = "authors" + database = database + metadata = metadata + + id: int = ormar.Integer(primary_key=True) + first_name: str = ormar.String(max_length=80) + last_name: str = ormar.String(max_length=80) + + +class Category(ormar.Model): + class Meta: + tablename = "categories" + database = database + metadata = metadata + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=40) + sort_order: int = ormar.Integer(nullable=True) + subject: Optional[Subject] = ormar.ForeignKey(Subject) + + +class PostCategory(ormar.Model): + class Meta: + tablename = "posts_categories" + database = database + metadata = metadata + + +class Post(ormar.Model): + class Meta: + tablename = "posts" + database = database + metadata = metadata + + id: int = ormar.Integer(primary_key=True) + title: str = ormar.String(max_length=200) + categories: Optional[Union[Category, List[Category]]] = ormar.ManyToMany( + Category, through=PostCategory + ) + author: Optional[Author] = ormar.ForeignKey(Author) + + +@pytest.fixture(scope="module") +def event_loop(): + loop = asyncio.get_event_loop() + yield loop + loop.close() + + +@pytest.fixture(autouse=True, scope="module") +async def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_queryset_methods(): + async with database: + async with database.transaction(force_rollback=True): + guido = await Author.objects.create( + first_name="Guido", last_name="Van Rossum" + ) + subject = await Subject(name="Random").save() + post = await Post.objects.create(title="Hello, M2M", author=guido) + news = await Category.objects.create( + name="News", sort_order=1, subject=subject + ) + breaking = await Category.objects.create( + name="Breaking", sort_order=3, subject=subject + ) + + # Add a category to a post. + await post.categories.add(news) + await post.categories.add(breaking) + + category = await post.categories.get_or_create(name="News") + assert category == news + assert len(post.categories) == 1 + + category = await post.categories.get_or_create(name="Breaking News") + assert category != breaking + assert category.pk is not None + assert len(post.categories) == 2 + + await post.categories.update_or_create(pk=category.pk, name="Urgent News") + assert len(post.categories) == 2 + cat = await post.categories.get_or_create(name="Urgent News") + assert cat.pk == category.pk + assert len(post.categories) == 1 + + await post.categories.remove(cat) + await cat.delete() + + assert len(post.categories) == 0 + + category = await post.categories.update_or_create( + name="Weather News", sort_order=2, subject=subject + ) + assert category.pk is not None + assert category.posts[0] == post + + assert len(post.categories) == 1 + + categories = await post.categories.all() + assert len(categories) == 3 == len(post.categories) + + assert await post.categories.exists() + assert 3 == await post.categories.count() + + categories = await post.categories.limit(2).all() + assert len(categories) == 2 == len(post.categories) + + categories2 = await post.categories.limit(2).offset(1).all() + assert len(categories2) == 2 == len(post.categories) + assert categories != categories2 + + categories = await post.categories.order_by("-sort_order").all() + assert len(categories) == 3 == len(post.categories) + assert post.categories[2].name == "News" + assert post.categories[0].name == "Breaking" + + categories = await post.categories.exclude(name__icontains="news").all() + assert len(categories) == 1 == len(post.categories) + assert post.categories[0].name == "Breaking" + + categories = ( + await post.categories.filter(name__icontains="news") + .order_by("-name") + .all() + ) + assert len(categories) == 2 == len(post.categories) + assert post.categories[0].name == "Weather News" + assert post.categories[1].name == "News" + + categories = await post.categories.fields("name").all() + assert len(categories) == 3 == len(post.categories) + for cat in post.categories: + assert cat.sort_order is None + + categories = await post.categories.exclude_fields("sort_order").all() + assert len(categories) == 3 == len(post.categories) + for cat in post.categories: + assert cat.sort_order is None + assert cat.subject.name is None + + categories = await post.categories.select_related("subject").all() + assert len(categories) == 3 == len(post.categories) + for cat in post.categories: + assert cat.subject.name is not None + + categories = await post.categories.prefetch_related("subject").all() + assert len(categories) == 3 == len(post.categories) + for cat in post.categories: + assert cat.subject.name is not None diff --git a/tests/test_reverse_fk_queryset.py b/tests/test_reverse_fk_queryset.py new file mode 100644 index 000000000..0eac5d5ab --- /dev/null +++ b/tests/test_reverse_fk_queryset.py @@ -0,0 +1,233 @@ +from typing import Optional + +import databases +import pytest +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class Album(ormar.Model): + class Meta: + tablename = "albums" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + is_best_seller: bool = ormar.Boolean(default=False) + + +class Writer(ormar.Model): + class Meta: + tablename = "writers" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + + +class Track(ormar.Model): + class Meta: + tablename = "tracks" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + album: Optional[Album] = ormar.ForeignKey(Album) + title: str = ormar.String(max_length=100) + position: int = ormar.Integer() + play_count: int = ormar.Integer(nullable=True) + written_by: Optional[Writer] = ormar.ForeignKey(Writer) + + +@pytest.fixture(autouse=True) +@pytest.mark.asyncio +async def sample_data(): + album = await Album(name="Malibu").save() + writer1 = await Writer.objects.create(name="John") + writer2 = await Writer.objects.create(name="Sue") + track1 = await Track( + album=album, title="The Bird", position=1, play_count=30, written_by=writer1 + ).save() + track2 = await Track( + album=album, + title="Heart don't stand a chance", + position=2, + play_count=20, + written_by=writer2, + ).save() + tracks3 = await Track( + album=album, title="The Waters", position=3, play_count=10, written_by=writer1 + ).save() + return album, [track1, track2, tracks3] + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_quering_by_reverse_fk(sample_data): + async with database: + async with database.transaction(force_rollback=True): + track1 = sample_data[1][0] + album = await Album.objects.first() + + assert await album.tracks.exists() + assert await album.tracks.count() == 3 + + track = await album.tracks.get_or_create( + title="The Bird", position=1, play_count=30 + ) + assert track == track1 + assert len(album.tracks) == 1 + + track = await album.tracks.get_or_create( + title="The Bird2", position=4, play_count=5 + ) + assert track != track1 + assert track.pk is not None + assert len(album.tracks) == 2 + + await album.tracks.update_or_create(pk=track.pk, play_count=50) + assert len(album.tracks) == 2 + track = await album.tracks.get_or_create(title="The Bird2") + assert track.play_count == 50 + assert len(album.tracks) == 1 + + await album.tracks.remove(track) + assert track.album is None + await track.delete() + + assert len(album.tracks) == 0 + + track6 = await album.tracks.update_or_create( + title="The Bird3", position=4, play_count=5 + ) + assert track6.pk is not None + assert track6.play_count == 5 + + assert len(album.tracks) == 1 + + await album.tracks.remove(track6) + assert track6.album is None + await track6.delete() + + assert len(album.tracks) == 0 + + +@pytest.mark.asyncio +async def test_getting(sample_data): + async with database: + async with database.transaction(force_rollback=True): + album = sample_data[0] + track1 = await album.tracks.fields(["album", "title", "position"]).get( + title="The Bird" + ) + track2 = await album.tracks.exclude_fields("play_count").get( + title="The Bird" + ) + for track in [track1, track2]: + assert track.title == "The Bird" + assert track.album == album + assert track.play_count is None + + assert len(album.tracks) == 1 + + tracks = await album.tracks.all() + assert len(tracks) == 3 + + assert len(album.tracks) == 3 + + tracks = await album.tracks.order_by("play_count").all() + assert len(tracks) == 3 + assert tracks[0].title == "The Waters" + assert tracks[2].title == "The Bird" + + assert len(album.tracks) == 3 + + track = await album.tracks.create( + title="The Bird Fly Away", position=4, play_count=10 + ) + assert track.title == "The Bird Fly Away" + assert track.position == 4 + assert track.album == album + + assert len(album.tracks) == 4 + + tracks = await album.tracks.all() + assert len(tracks) == 4 + + tracks = await album.tracks.limit(2).all() + assert len(tracks) == 2 + + tracks2 = await album.tracks.limit(2).offset(2).all() + assert len(tracks2) == 2 + assert tracks != tracks2 + + tracks3 = await album.tracks.filter(play_count__lt=15).all() + assert len(tracks3) == 2 + + tracks4 = await album.tracks.exclude(play_count__lt=15).all() + assert len(tracks4) == 2 + assert tracks3 != tracks4 + + assert len(album.tracks) == 2 + + await album.tracks.clear() + tracks = await album.tracks.all() + assert len(tracks) == 0 + assert len(album.tracks) == 0 + + +@pytest.mark.asyncio +async def test_loading_related(sample_data): + async with database: + async with database.transaction(force_rollback=True): + album = sample_data[0] + tracks = await album.tracks.select_related("written_by").all() + assert len(tracks) == 3 + assert len(album.tracks) == 3 + for track in tracks: + assert track.written_by is not None + + tracks = await album.tracks.prefetch_related("written_by").all() + assert len(tracks) == 3 + assert len(album.tracks) == 3 + for track in tracks: + assert track.written_by is not None + + +@pytest.mark.asyncio +async def test_adding_removing(sample_data): + async with database: + async with database.transaction(force_rollback=True): + album = sample_data[0] + track_new = await Track(title="Rainbow", position=5, play_count=300).save() + await album.tracks.add(track_new) + assert track_new.album == album + assert len(album.tracks) == 4 + + track_check = await Track.objects.get(title="Rainbow") + assert track_check.album == album + + track_test = await Track.objects.get(title="Rainbow") + assert track_test.album == album + + await album.tracks.remove(track_new) + assert track_new.album is None + assert len(album.tracks) == 3 + + track_test = await Track.objects.get(title="Rainbow") + assert track_test.album is None From 4c4e6248b05a64c182f3ea7fb97fa499cc020eba Mon Sep 17 00:00:00 2001 From: collerek Date: Tue, 1 Dec 2020 08:34:26 +0100 Subject: [PATCH 2/3] fix for sample data in tests --- tests/test_reverse_fk_queryset.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/test_reverse_fk_queryset.py b/tests/test_reverse_fk_queryset.py index 0eac5d5ab..e79d55be4 100644 --- a/tests/test_reverse_fk_queryset.py +++ b/tests/test_reverse_fk_queryset.py @@ -46,9 +46,7 @@ class Meta: written_by: Optional[Writer] = ormar.ForeignKey(Writer) -@pytest.fixture(autouse=True) -@pytest.mark.asyncio -async def sample_data(): +async def get_sample_data(): album = await Album(name="Malibu").save() writer1 = await Writer.objects.create(name="John") writer2 = await Writer.objects.create(name="Sue") @@ -78,9 +76,10 @@ def create_test_database(): @pytest.mark.asyncio -async def test_quering_by_reverse_fk(sample_data): +async def test_quering_by_reverse_fk(): async with database: async with database.transaction(force_rollback=True): + sample_data = await get_sample_data() track1 = sample_data[1][0] album = await Album.objects.first() @@ -128,9 +127,10 @@ async def test_quering_by_reverse_fk(sample_data): @pytest.mark.asyncio -async def test_getting(sample_data): +async def test_getting(): async with database: async with database.transaction(force_rollback=True): + sample_data = await get_sample_data() album = sample_data[0] track1 = await album.tracks.fields(["album", "title", "position"]).get( title="The Bird" @@ -192,9 +192,10 @@ async def test_getting(sample_data): @pytest.mark.asyncio -async def test_loading_related(sample_data): +async def test_loading_related(): async with database: async with database.transaction(force_rollback=True): + sample_data = await get_sample_data() album = sample_data[0] tracks = await album.tracks.select_related("written_by").all() assert len(tracks) == 3 @@ -210,9 +211,10 @@ async def test_loading_related(sample_data): @pytest.mark.asyncio -async def test_adding_removing(sample_data): +async def test_adding_removing(): async with database: async with database.transaction(force_rollback=True): + sample_data = await get_sample_data() album = sample_data[0] track_new = await Track(title="Rainbow", position=5, play_count=300).save() await album.tracks.add(track_new) From 3ac767ed0f470438d1b1b8341bf60628775d585c Mon Sep 17 00:00:00 2001 From: collerek Date: Tue, 1 Dec 2020 10:41:07 +0100 Subject: [PATCH 3/3] bump version, update docs --- README.md | 2 +- docs/index.md | 2 +- docs/queries.md | 2 +- docs/relations.md | 313 ++++++++++++++++++++++++++---- docs/releases.md | 12 ++ docs_src/fields/docs001.py | 2 +- mkdocs.yml | 2 +- ormar/__init__.py | 2 +- ormar/models/model.py | 6 +- ormar/queryset/queryset.py | 4 +- ormar/relations/querysetproxy.py | 8 +- ormar/relations/relation_proxy.py | 27 ++- tests/test_reverse_fk_queryset.py | 29 ++- 13 files changed, 347 insertions(+), 64 deletions(-) diff --git a/README.md b/README.md index b4b1890fa..f7adfbe26 100644 --- a/README.md +++ b/README.md @@ -154,7 +154,7 @@ assert len(tracks) == 1 * `bulk_create(objects: List[Model]) -> None` * `bulk_update(objects: List[Model], columns: List[str] = None) -> None` * `delete(each: bool = False, **kwargs) -> int` -* `all(self, **kwargs) -> List[Optional[Model]]` +* `all(**kwargs) -> List[Optional[Model]]` * `filter(**kwargs) -> QuerySet` * `exclude(**kwargs) -> QuerySet` * `select_related(related: Union[List, str]) -> QuerySet` diff --git a/docs/index.md b/docs/index.md index b4b1890fa..f7adfbe26 100644 --- a/docs/index.md +++ b/docs/index.md @@ -154,7 +154,7 @@ assert len(tracks) == 1 * `bulk_create(objects: List[Model]) -> None` * `bulk_update(objects: List[Model], columns: List[str] = None) -> None` * `delete(each: bool = False, **kwargs) -> int` -* `all(self, **kwargs) -> List[Optional[Model]]` +* `all(**kwargs) -> List[Optional[Model]]` * `filter(**kwargs) -> QuerySet` * `exclude(**kwargs) -> QuerySet` * `select_related(related: Union[List, str]) -> QuerySet` diff --git a/docs/queries.md b/docs/queries.md index 17b23ceee..1af43826e 100644 --- a/docs/queries.md +++ b/docs/queries.md @@ -176,7 +176,7 @@ Return number of rows deleted. ### all -`all(self, **kwargs) -> List[Optional["Model"]]` +`all(**kwargs) -> List[Optional["Model"]]` Returns all rows from a database for given model for set filter options. diff --git a/docs/relations.md b/docs/relations.md index 5bda020ad..b3e548b97 100644 --- a/docs/relations.md +++ b/docs/relations.md @@ -29,6 +29,83 @@ By default it's child (source) `Model` name + s, like courses in snippet below: --8<-- "../docs_src/fields/docs001.py" ``` +Reverse relation exposes API to manage related objects also from parent side. + +##### add + +Adding child model from parent side causes adding related model to currently loaded parent relation, +as well as sets child's model foreign key value and updates the model. + +```python +department = await Department(name="Science").save() +course = Course(name="Math", completed=False) # note - not saved + +await department.courses.add(course) +assert course.pk is not None # child model was saved +# relation on child model is set and FK column saved in db +assert courses.department == department +# relation on parent model is also set +assert department.courses[0] == course +``` + +!!!warning + If you want to add child model on related model the primary key value for parent model **has to exist in database**. + + Otherwise ormar will raise RelationshipInstanceError as it cannot set child's ForeignKey column value + if parent model has no primary key value. + + That means that in example above the department has to be saved before you can call `department.courses.add()`. + +##### remove + +Removal of the related model one by one. + +In reverse relation calling `remove()` does not remove the child model, but instead nulls it ForeignKey value. + +```python +# continuing from above +await department.courses.remove(course) +assert len(department.courses) == 0 +# course still exists and was saved in remove +assert course.pk is not None +assert course.department is None + +# to remove child from db +await course.delete() +``` + +But if you want to clear the relation and delete the child at the same time you can issue: + +```python +# this will not only clear the relation +# but also delete related course from db +await department.courses.remove(course, keep_reversed=False) +``` + +##### clear + +Removal of all related models in one call. + +Like remove by default `clear()` nulls the ForeigKey column on child model (all, not matter if they are loaded or not). + +```python +# nulls department column on all courses related to this department +await department.courses.clear() +``` + +If you want to remove the children altogether from the database, set `keep_reversed=False` + +```python +# deletes from db all courses related to this department +await department.courses.clear(keep_reversed=False) +``` + +##### QuerysetProxy + +Reverse relation exposes QuerysetProxy API that allows you to query related model like you would issue a normal Query. + +To read which methods of QuerySet are available read below [querysetproxy][querysetproxy] + #### related_name But you can overwrite this name by providing `related_name` parameter like below: @@ -94,7 +171,7 @@ Sqlalchemy column and Type are automatically taken from target `Model`. * Sqlalchemy column: class of a target `Model` primary key column * Type (used for pydantic): type of a target `Model` -####Defining `Models`: +####Defining `Models` ```Python --8<-- "../docs_src/relations/docs002.py" @@ -107,7 +184,7 @@ post = await Post.objects.create(title="Hello, M2M", author=guido) news = await Category.objects.create(name="News") ``` -#### Adding related models +#### add ```python # Add a category to a post. @@ -121,26 +198,7 @@ await news.posts.add(post) Otherwise an IntegrityError will be raised by your database driver library. -#### create() - -Create related `Model` directly from parent `Model`. - -The link table is automatically populated, as well as relation ids in the database. - -```python -# Creating columns object from instance: -await post.categories.create(name="Tips") -assert len(await post.categories.all()) == 2 -# newly created instance already have relation persisted in the database -``` - -!!!note - Note that when accessing QuerySet API methods through ManyToMany relation you don't - need to use objects attribute like in normal queries. - - To learn more about available QuerySet methods visit [queries][queries] - -#### remove() +#### remove Removal of the related model one by one. @@ -150,9 +208,9 @@ Removes also the relation in the database. await news.posts.remove(post) ``` -#### clear() +#### clear -Removal all related models in one call. +Removal of all related models in one call. Removes also the relation in the database. @@ -160,17 +218,75 @@ Removes also the relation in the database. await news.posts.clear() ``` -#### Other queryset methods +#### QuerysetProxy + +Reverse relation exposes QuerysetProxy API that allows you to query related model like you would issue a normal Query. + +To read which methods of QuerySet are available read below [querysetproxy][querysetproxy] + +### QuerySetProxy + +When access directly the related `ManyToMany` field as well as `ReverseForeignKey` returns the list of related models. + +But at the same time it exposes subset of QuerySet API, so you can filter, create, select related etc related models directly from parent model. -When access directly the related `ManyToMany` field returns the list of related models. +!!!note + By default exposed QuerySet is already filtered to return only `Models` related to parent `Model`. + + So if you issue `post.categories.all()` you will get all categories related to that post, not all in table. + +!!!note + Note that when accessing QuerySet API methods through QuerysetProxy you don't + need to use `objects` attribute like in normal queries. + + So note that it's `post.categories.all()` and **not** `post.categories.objects.all()`. + + To learn more about available QuerySet methods visit [queries][queries] + +!!!warning + Querying related models from ManyToMany cleans list of related models loaded on parent model: + + Example: `post.categories.first()` will set post.categories to list of 1 related model -> the one returned by first() + + Example 2: if post has 4 categories so `len(post.categories) == 4` calling `post.categories.limit(2).all()` + -> will load only 2 children and now `assert len(post.categories) == 2` + + This happens for all QuerysetProxy methods returning data: `get`, `all` and `first` and in `get_or_create` if model already exists. + + Note that value returned by `create` or created in `get_or_create` and `update_or_create` + if model does not exist will be added to relation list (not clearing it). + +#### get -But at the same time it exposes full QuerySet API, so you can filter, create, select related etc. +`get(**kwargs): -> Model` + +To grab just one of related models filtered by name you can use `get(**kwargs)` method. ```python -# Many to many relation exposes a list of columns models -# and an API of the Queryset: +# grab one category assert news == await post.categories.get(name="News") +# note that method returns the category so you can grab this value +# but it also modifies list of related models in place +# so regardless of what was previously loaded on parent model +# now it has only one value -> just loaded with get() call +assert len(post.categories) == 1 +assert post.categories[0] == news + +``` + +!!!tip + Read more in queries documentation [get][get] + +#### all + +`all(**kwargs) -> List[Optional["Model"]]` + +To get a list of related models use `all()` method. + +Note that you can filter the queryset, select related, exclude fields etc. like in normal query. + +```python # with all Queryset methods - filtering, selecting columns, counting etc. await news.posts.filter(title__contains="M2M").all() await Category.objects.filter(posts__author=guido).get() @@ -180,18 +296,135 @@ news_posts = await news.posts.select_related("author").all() assert news_posts[0].author == guido ``` -Currently supported methods are: +!!!tip + Read more in queries documentation [all][all] + +#### create + +`create(**kwargs): -> Model` + +Create related `Model` directly from parent `Model`. + +The link table is automatically populated, as well as relation ids in the database. + +```python +# Creating columns object from instance: +await post.categories.create(name="Tips") +assert len(await post.categories.all()) == 2 +# newly created instance already have relation persisted in the database +``` !!!tip - To learn more about available QuerySet methods visit [queries][queries] + Read more in queries documentation [create][create] + + +#### get_or_create + +`get_or_create(**kwargs) -> Model` + +!!!tip + Read more in queries documentation [get_or_create][get_or_create] + +#### update_or_create -##### get() -##### all() -##### filter() -##### select_related() -##### limit() -##### offset() -##### count() -##### exists() +`update_or_create(**kwargs) -> Model` -[queries]: ./queries.md \ No newline at end of file +!!!tip + Read more in queries documentation [update_or_create][update_or_create] + +#### filter + +`filter(**kwargs) -> QuerySet` + +!!!tip + Read more in queries documentation [filter][filter] + +#### exclude + +`exclude(**kwargs) -> QuerySet` + +!!!tip + Read more in queries documentation [exclude][exclude] + +#### select_related + +`select_related(related: Union[List, str]) -> QuerySet` + +!!!tip + Read more in queries documentation [select_related][select_related] + +#### prefetch_related + +`prefetch_related(related: Union[List, str]) -> QuerySet` + +!!!tip + Read more in queries documentation [prefetch_related][prefetch_related] + +#### limit + +`limit(limit_count: int) -> QuerySet` + +!!!tip + Read more in queries documentation [limit][limit] + +#### offset + +`offset(offset: int) -> QuerySet` + +!!!tip + Read more in queries documentation [offset][offset] + +#### count + +`count() -> int` + +!!!tip + Read more in queries documentation [count][count] + +#### exists + +`exists() -> bool` + +!!!tip + Read more in queries documentation [exists][exists] + +#### fields + +`fields(columns: Union[List, str, set, dict]) -> QuerySet` + +!!!tip + Read more in queries documentation [fields][fields] + +#### exclude_fields + +`exclude_fields(columns: Union[List, str, set, dict]) -> QuerySet` + +!!!tip + Read more in queries documentation [exclude_fields][exclude_fields] + +#### order_by + +`order_by(columns:Union[List, str]) -> QuerySet` + +!!!tip + Read more in queries documentation [order_by][order_by] + + +[queries]: ./queries.md +[querysetproxy]: ./relations.md#querysetproxy-methods +[get]: ./queries.md#get +[all]: ./queries.md#all +[create]: ./queries.md#create +[get_or_create]: ./queries.md#get_or_create +[update_or_create]: ./queries.md#update_or_create +[filter]: ./queries.md#filter +[exclude]: ./queries.md#exclude +[select_related]: ./queries.md#select_related +[prefetch_related]: ./queries.md#prefetch_related +[limit]: ./queries.md#limit +[offset]: ./queries.md#offset +[count]: ./queries.md#count +[exists]: ./queries.md#exists +[fields]: ./queries.md#fields +[exclude_fields]: ./queries.md#exclude_fields +[order_by]: ./queries.md#order_by \ No newline at end of file diff --git a/docs/releases.md b/docs/releases.md index f44d9d8b4..a0af0e3eb 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -1,3 +1,15 @@ +# 0.6.0 + +* **Breaking:** calling instance.load() when the instance row was deleted from db now raises ormar.NoMatch instead of ValueError +* **Breaking:** calling add and remove on ReverseForeignKey relation now updates the child model in db setting/removing fk column +* **Breaking:** ReverseForeignKey relation now exposes QuerySetProxy API like ManyToMany relation +* **Breaking:** querying related models from ManyToMany cleans list of related models loaded on parent model: + * Example: `post.categories.first()` will set post.categories to list of 1 related model -> the one returned by first() + * Example 2: if post has 4 categories so `len(post.categories) == 4` calling `post.categories.limit(2).all()` -> will load only 2 children and now `assert len(post.categories) == 2` +* Added `get_or_create`, `update_or_create`, `fields`, `exclude_fields`, `exclude`, `prefetch_related` and `order_by` to QuerySetProxy +so now you can use those methods directly from relation +* Update docs + # 0.5.5 * Fix for alembic autogenaration of migration `UUID` columns. It should just produce sqlalchemy `CHAR(32)` or `CHAR(36)` diff --git a/docs_src/fields/docs001.py b/docs_src/fields/docs001.py index dd7979143..d1c9144c5 100644 --- a/docs_src/fields/docs001.py +++ b/docs_src/fields/docs001.py @@ -29,7 +29,7 @@ class Meta: department: Optional[Department] = ormar.ForeignKey(Department) -department = Department(name="Science") +department = await Department(name="Science").save() course = Course(name="Math", completed=False, department=department) print(department.courses[0]) diff --git a/mkdocs.yml b/mkdocs.yml index 9c3dec678..ccaae7c05 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -1,5 +1,5 @@ site_name: ormar -site_description: An simple async ORM with fastapi in mind and pydantic validation. +site_description: A simple async ORM with fastapi in mind and pydantic validation. nav: - Overview: index.md - Installation: install.md diff --git a/ormar/__init__.py b/ormar/__init__.py index 73d614f69..774dea106 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -30,7 +30,7 @@ def __repr__(self) -> str: Undefined = UndefinedType() -__version__ = "0.5.5" +__version__ = "0.6.0" __all__ = [ "Integer", "BigInteger", diff --git a/ormar/models/model.py b/ormar/models/model.py index 70d8e42a1..df406ff45 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -15,7 +15,7 @@ import sqlalchemy import ormar.queryset # noqa I100 -from ormar.exceptions import ModelPersistenceError +from ormar.exceptions import ModelPersistenceError, NoMatch from ormar.fields.many_to_many import ManyToManyField from ormar.models import NewBaseModel # noqa I100 from ormar.models.metaclass import ModelMeta @@ -286,9 +286,7 @@ async def load(self: T) -> T: expr = self.Meta.table.select().where(self.pk_column == self.pk) row = await self.Meta.database.fetch_one(expr) if not row: # pragma nocover - raise ValueError( - "Instance was deleted from database and cannot be refreshed" - ) + raise NoMatch("Instance was deleted from database and cannot be refreshed") kwargs = dict(row) kwargs = self.translate_aliases_to_columns(kwargs) self.from_dict(kwargs) diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 02a0566dd..f6defd48f 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -280,7 +280,9 @@ async def count(self) -> int: return await self.database.fetch_val(expr) async def update(self, each: bool = False, **kwargs: Any) -> int: - self_fields = self.model.extract_db_own_fields() + self_fields = self.model.extract_db_own_fields().union( + self.model.extract_related_names() + ) updates = {k: v for k, v in kwargs.items() if k in self_fields} updates = self.model.translate_columns_to_aliases(updates) if not each and not self.filter_clauses: diff --git a/ormar/relations/querysetproxy.py b/ormar/relations/querysetproxy.py index 208b4d123..b5cad211b 100644 --- a/ormar/relations/querysetproxy.py +++ b/ormar/relations/querysetproxy.py @@ -88,7 +88,7 @@ async def exists(self) -> bool: async def count(self) -> int: return await self.queryset.count() - async def clear(self) -> int: + async def clear(self, keep_reversed: bool = True) -> int: if self.type_ == ormar.RelationType.MULTIPLE: queryset = ormar.QuerySet(model_cls=self.relation.through) owner_column = self._owner.get_name() @@ -97,10 +97,16 @@ async def clear(self) -> int: owner_column = self.related_field.name kwargs = {owner_column: self._owner} self._clean_items_on_load() + if keep_reversed and self.type_ == ormar.RelationType.REVERSE: + update_kwrgs = {f"{owner_column}": None} + return await queryset.filter(_exclude=False, **kwargs).update( + each=False, **update_kwrgs + ) return await queryset.delete(**kwargs) # type: ignore async def first(self, **kwargs: Any) -> "Model": first = await self.queryset.first(**kwargs) + self._clean_items_on_load() self._register_related(first) return first diff --git a/ormar/relations/relation_proxy.py b/ormar/relations/relation_proxy.py index 25c7b1ca2..c8eb94403 100644 --- a/ormar/relations/relation_proxy.py +++ b/ormar/relations/relation_proxy.py @@ -38,17 +38,20 @@ def _check_if_queryset_is_initialized(self) -> bool: and self.queryset_proxy.queryset is not None ) - def _set_queryset(self) -> "QuerySet": - related_field = self._owner.resolve_relation_field( - self.relation.to, self._owner - ) - pkname = self._owner.get_column_alias(self._owner.Meta.pkname) + def _check_if_model_saved(self) -> None: pk_value = self._owner.pk if not pk_value: raise RelationshipInstanceError( "You cannot query relationships from unsaved model." ) - kwargs = {f"{related_field.get_alias()}__{pkname}": pk_value} + + def _set_queryset(self) -> "QuerySet": + related_field = self._owner.resolve_relation_field( + self.relation.to, self._owner + ) + pkname = self._owner.get_column_alias(self._owner.Meta.pkname) + self._check_if_model_saved() + kwargs = {f"{related_field.get_alias()}__{pkname}": self._owner.pk} queryset = ( ormar.QuerySet(model_cls=self.relation.to) .select_related(related_field.name) @@ -56,7 +59,9 @@ def _set_queryset(self) -> "QuerySet": ) return queryset - async def remove(self, item: "Model") -> None: # type: ignore + async def remove( # type: ignore + self, item: "Model", keep_reversed: bool = True + ) -> None: if item not in self: raise NoMatch( f"Object {self._owner.get_name()} has no " @@ -74,8 +79,11 @@ async def remove(self, item: "Model") -> None: # type: ignore if self.type_ == ormar.RelationType.MULTIPLE: await self.queryset_proxy.delete_through_instance(item) else: - setattr(item, rel_name, None) - await item.update() + if keep_reversed: + setattr(item, rel_name, None) + await item.update() + else: + await item.delete() async def add(self, item: "Model") -> None: if self.type_ == ormar.RelationType.MULTIPLE: @@ -83,6 +91,7 @@ async def add(self, item: "Model") -> None: rel_name = item.resolve_relation_name(item, self._owner) setattr(item, rel_name, self._owner) else: + self._check_if_model_saved() related_field = self._owner.resolve_relation_field( self.relation.to, self._owner ) diff --git a/tests/test_reverse_fk_queryset.py b/tests/test_reverse_fk_queryset.py index e79d55be4..dd2c49be0 100644 --- a/tests/test_reverse_fk_queryset.py +++ b/tests/test_reverse_fk_queryset.py @@ -5,6 +5,7 @@ import sqlalchemy import ormar +from ormar import NoMatch from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL, force_rollback=True) @@ -190,6 +191,26 @@ async def test_getting(): assert len(tracks) == 0 assert len(album.tracks) == 0 + still_tracks = await Track.objects.all() + assert len(still_tracks) == 4 + for track in still_tracks: + assert track.album is None + + +@pytest.mark.asyncio +async def test_cleaning_related(): + async with database: + async with database.transaction(force_rollback=True): + sample_data = await get_sample_data() + album = sample_data[0] + await album.tracks.clear(keep_reversed=False) + tracks = await album.tracks.all() + assert len(tracks) == 0 + assert len(album.tracks) == 0 + + no_tracks = await Track.objects.all() + assert len(no_tracks) == 0 + @pytest.mark.asyncio async def test_loading_related(): @@ -224,12 +245,14 @@ async def test_adding_removing(): track_check = await Track.objects.get(title="Rainbow") assert track_check.album == album - track_test = await Track.objects.get(title="Rainbow") - assert track_test.album == album - await album.tracks.remove(track_new) assert track_new.album is None assert len(album.tracks) == 3 + track1 = album.tracks[0] + await album.tracks.remove(track1, keep_reversed=False) + with pytest.raises(NoMatch): + await track1.load() + track_test = await Track.objects.get(title="Rainbow") assert track_test.album is None