diff --git a/.coverage b/.coverage index 1ab8b165a..6a416ffac 100644 Binary files a/.coverage and b/.coverage differ diff --git a/README.md b/README.md index 7ed8dfe5b..ef48fe573 100644 --- a/README.md +++ b/README.md @@ -175,6 +175,86 @@ tracks = await Track.objects.limit(1).all() assert len(tracks) == 1 ``` +Since version >=0.3 Ormar supports also many to many relationships +```python +import databases +import ormar +import sqlalchemy + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + +class Author(ormar.Model): + class Meta: + tablename = "authors" + database = database + metadata = metadata + + id: ormar.Integer(primary_key=True) + first_name: ormar.String(max_length=80) + last_name: ormar.String(max_length=80) + + +class Category(ormar.Model): + class Meta: + tablename = "categories" + database = database + metadata = metadata + + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=40) + + +class PostCategory(ormar.Model): + class Meta: + tablename = "posts_categories" + database = database + metadata = metadata + + +class Post(ormar.Model): + class Meta: + tablename = "posts" + database = database + metadata = metadata + + id: ormar.Integer(primary_key=True) + title: ormar.String(max_length=200) + categories: ormar.ManyToMany(Category, through=PostCategory) + author: ormar.ForeignKey(Author) + +guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") +post = await Post.objects.create(title="Hello, M2M", author=guido) +news = await Category.objects.create(name="News") + +# Add a category to a post. +await post.categories.add(news) +# or from the other end: +await news.posts.add(post) + +# Creating related object from instance: +await post.categories.create(name="Tips") +assert len(await post.categories.all()) == 2 + +# Many to many relation exposes a list of related models +# and an API of the Queryset: +assert news == await post.categories.get(name="News") + +# with all Queryset methods - filtering, selecting related, counting etc. +await news.posts.filter(title__contains="M2M").all() +await Category.objects.filter(posts__author=guido).get() + +# related models of many to many relation can be prefetched +news_posts = await news.posts.select_related("author").all() +assert news_posts[0].author == guido + +# Removal of the relationship by one +await news.posts.remove(post) +# or all at once +await news.posts.clear() + +``` + ## Data types The following keyword arguments are supported on all field types. diff --git a/ormar/__init__.py b/ormar/__init__.py index 37ebe34b0..f638111ea 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -9,13 +9,14 @@ ForeignKey, Integer, JSON, + ManyToMany, String, Text, Time, ) from ormar.models import Model -__version__ = "0.2.2" +__version__ = "0.3.0" __all__ = [ "Integer", "BigInteger", @@ -28,6 +29,7 @@ "Date", "Decimal", "Float", + "ManyToMany", "Model", "ModelDefinitionError", "ModelNotSet", diff --git a/ormar/fields/__init__.py b/ormar/fields/__init__.py index f6c4dc93a..22c2665a7 100644 --- a/ormar/fields/__init__.py +++ b/ormar/fields/__init__.py @@ -1,5 +1,6 @@ from ormar.fields.base import BaseField from ormar.fields.foreign_key import ForeignKey +from ormar.fields.many_to_many import ManyToMany from ormar.fields.model_fields import ( BigInteger, Boolean, @@ -27,5 +28,6 @@ "Float", "Time", "ForeignKey", + "ManyToMany", "BaseField", ] diff --git a/ormar/fields/base.py b/ormar/fields/base.py index 126a3a689..f3e83a4f8 100644 --- a/ormar/fields/base.py +++ b/ormar/fields/base.py @@ -22,6 +22,7 @@ class BaseField: index: bool unique: bool pydantic_only: bool + virtual: bool = False default: Any server_default: Any @@ -34,8 +35,7 @@ def default_value(cls) -> Optional[Field]: default = cls.default if cls.default is not None else cls.server_default if callable(default): return Field(default_factory=default) - else: - return Field(default=default) + return Field(default=default) return None @classmethod diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index ebd4ec636..1fac75ee4 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -22,7 +22,7 @@ def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model": return fk(**init_dict) -def ForeignKey( +def ForeignKey( # noqa CFQ002 to: Type["Model"], *, name: str = None, @@ -30,6 +30,8 @@ def ForeignKey( nullable: bool = True, related_name: str = None, virtual: bool = False, + onupdate: str = None, + ondelete: str = None, ) -> Type[object]: fk_string = to.Meta.tablename + "." + to.Meta.pkname to_field = to.__fields__[to.Meta.pkname] @@ -37,7 +39,11 @@ def ForeignKey( to=to, name=name, nullable=nullable, - constraints=[sqlalchemy.schema.ForeignKey(fk_string)], + constraints=[ + sqlalchemy.schema.ForeignKey( + fk_string, ondelete=ondelete, onupdate=onupdate + ) + ], unique=unique, column_type=to_field.type_.column_type, related_name=related_name, @@ -117,7 +123,7 @@ def expand_relationship( cls, value: Any, child: "Model", to_register: bool = True ) -> Optional[Union["Model", List["Model"]]]: if value is None: - return None + return None if not cls.virtual else [] constructors = { f"{cls.to.__name__}": cls._register_existing_model, diff --git a/ormar/fields/many_to_many.py b/ormar/fields/many_to_many.py new file mode 100644 index 000000000..89d5f6288 --- /dev/null +++ b/ormar/fields/many_to_many.py @@ -0,0 +1,40 @@ +from typing import TYPE_CHECKING, Type + +from ormar.fields import BaseField +from ormar.fields.foreign_key import ForeignKeyField + +if TYPE_CHECKING: # pragma no cover + from ormar.models import Model + + +def ManyToMany( + to: Type["Model"], + through: Type["Model"], + *, + name: str = None, + unique: bool = False, + related_name: str = None, + virtual: bool = False, +) -> Type[object]: + to_field = to.__fields__[to.Meta.pkname] + namespace = dict( + to=to, + through=through, + name=name, + nullable=True, + unique=unique, + column_type=to_field.type_.column_type, + related_name=related_name, + virtual=virtual, + primary_key=False, + index=False, + pydantic_only=False, + default=None, + server_default=None, + ) + + return type("ManyToMany", (ManyToManyField, BaseField), namespace) + + +class ManyToManyField(ForeignKeyField): + through: Type["Model"] diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 8bc820d64..08475ae89 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -1,16 +1,18 @@ +import logging from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union import databases import pydantic import sqlalchemy from pydantic import BaseConfig -from pydantic.fields import FieldInfo +from pydantic.fields import FieldInfo, ModelField -from ormar import ForeignKey, ModelDefinitionError # noqa I100 +from ormar import ForeignKey, ModelDefinitionError, Integer # noqa I100 from ormar.fields import BaseField from ormar.fields.foreign_key import ForeignKeyField +from ormar.fields.many_to_many import ManyToMany, ManyToManyField from ormar.queryset import QuerySet -from ormar.relations import AliasManager +from ormar.relations.alias_manager import AliasManager if TYPE_CHECKING: # pragma no cover from ormar import Model @@ -30,7 +32,14 @@ class ModelMeta: def register_relation_on_build(table_name: str, field: ForeignKey) -> None: - alias_manager.add_relation_type(field, table_name) + alias_manager.add_relation_type(field.to.Meta.tablename, table_name) + + +def register_many_to_many_relation_on_build(table_name: str, field: ManyToMany) -> None: + alias_manager.add_relation_type(field.through.Meta.tablename, table_name) + alias_manager.add_relation_type( + field.through.Meta.tablename, field.to.Meta.tablename + ) def reverse_field_not_already_registered( @@ -51,17 +60,74 @@ def expand_reverse_relationships(model: Type["Model"]) -> None: if reverse_field_not_already_registered( child, child_model_name, parent_model ): - register_reverse_model_fields(parent_model, child, child_model_name) + register_reverse_model_fields( + parent_model, child, child_model_name, model_field + ) def register_reverse_model_fields( - model: Type["Model"], child: Type["Model"], child_model_name: str + model: Type["Model"], + child: Type["Model"], + child_model_name: str, + model_field: Type["ForeignKeyField"], +) -> None: + if issubclass(model_field, ManyToManyField): + model.Meta.model_fields[child_model_name] = ManyToMany( + child, through=model_field.through, name=child_model_name, virtual=True + ) + # register foreign keys on through model + adjust_through_many_to_many_model(model, child, model_field) + else: + model.Meta.model_fields[child_model_name] = ForeignKey( + child, name=child_model_name, virtual=True + ) + + +def adjust_through_many_to_many_model( + model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField] +) -> None: + model_field.through.Meta.model_fields[model.get_name()] = ForeignKey( + model, name=model.get_name(), ondelete="CASCADE" + ) + model_field.through.Meta.model_fields[child.get_name()] = ForeignKey( + child, name=child.get_name(), ondelete="CASCADE" + ) + + create_and_append_m2m_fk(model, model_field) + create_and_append_m2m_fk(child, model_field) + + create_pydantic_field(model.get_name(), model, model_field) + create_pydantic_field(child.get_name(), child, model_field) + + +def create_pydantic_field( + field_name: str, model: Type["Model"], model_field: Type[ManyToManyField] ) -> None: - model.Meta.model_fields[child_model_name] = ForeignKey( - child, name=child_model_name, virtual=True + model_field.through.__fields__[field_name] = ModelField( + name=field_name, + type_=Optional[model], + model_config=model.__config__, + required=False, + class_validators=model.__validators__, ) +def create_and_append_m2m_fk( + model: Type["Model"], model_field: Type[ManyToManyField] +) -> None: + column = sqlalchemy.Column( + model.get_name(), + model.Meta.table.columns.get(model.Meta.pkname).type, + sqlalchemy.schema.ForeignKey( + model.Meta.tablename + "." + model.Meta.pkname, + ondelete="CASCADE", + onupdate="CASCADE", + ), + ) + model_field.through.Meta.columns.append(column) + model_field.through.Meta.table.append_column(column) + + def check_pk_column_validity( field_name: str, field: BaseField, pkname: str ) -> Optional[str]: @@ -77,17 +143,34 @@ def sqlalchemy_columns_from_model_fields( ) -> Tuple[Optional[str], List[sqlalchemy.Column]]: columns = [] pkname = None + if len(model_fields.keys()) == 0: + model_fields["id"] = Integer(name="id", primary_key=True) + logging.warning( + "Table {table_name} had no fields so auto " + "Integer primary key named `id` created." + ) for field_name, field in model_fields.items(): if field.primary_key: pkname = check_pk_column_validity(field_name, field, pkname) - if not field.pydantic_only: + if ( + not field.pydantic_only + and not field.virtual + and not issubclass(field, ManyToManyField) + ): columns.append(field.get_column(field_name)) - if issubclass(field, ForeignKeyField): - register_relation_on_build(table_name, field) - + register_relation_in_alias_manager(table_name, field) return pkname, columns +def register_relation_in_alias_manager( + table_name: str, field: Type[ForeignKeyField] +) -> None: + if issubclass(field, ManyToManyField): + register_many_to_many_relation_on_build(table_name, field) + elif issubclass(field, ForeignKeyField): + register_relation_on_build(table_name, field) + + def populate_default_pydantic_field_value( type_: Type[BaseField], field: str, attrs: dict ) -> dict: @@ -109,15 +192,11 @@ def populate_pydantic_default_values(attrs: Dict) -> Dict: return attrs -def extract_annotations_and_module( - attrs: dict, new_model: "ModelMetaclass", bases: Tuple -) -> dict: - annotations = attrs.get("__annotations__") or new_model.__annotations__ - attrs["__annotations__"] = annotations +def extract_annotations_and_default_vals(attrs: dict, bases: Tuple) -> dict: + attrs["__annotations__"] = attrs.get("__annotations__") or bases[0].__dict__.get( + "__annotations__", {} + ) attrs = populate_pydantic_default_values(attrs) - - attrs["__module__"] = attrs["__module__"] or bases[0].__module__ - attrs["__annotations__"] = attrs["__annotations__"] or bases[0].__annotations__ return attrs @@ -175,20 +254,26 @@ class Config(BaseConfig): class ModelMetaclass(pydantic.main.ModelMetaclass): def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type: - attrs["Config"] = get_pydantic_base_orm_config() + attrs = extract_annotations_and_default_vals(attrs, bases) new_model = super().__new__( # type: ignore mcs, name, bases, attrs ) + # breakpoint() if hasattr(new_model, "Meta"): - - attrs = extract_annotations_and_module(attrs, new_model, bases) + # attrs = extract_annotations_and_default_vals(attrs, bases) new_model = populate_meta_orm_model_fields(attrs, new_model) new_model = populate_meta_tablename_columns_and_pk(name, new_model) new_model = populate_meta_sqlalchemy_table_if_required(new_model) expand_reverse_relationships(new_model) + if new_model.Meta.pkname not in attrs["__annotations__"]: + field_name = new_model.Meta.pkname + field = Integer(name=field_name, primary_key=True) + attrs["__annotations__"][field_name] = field + populate_default_pydantic_field_value(field, field_name, attrs) + new_model = super().__new__( # type: ignore mcs, name, bases, attrs ) diff --git a/ormar/models/model.py b/ormar/models/model.py index 1b40edb0e..43a67428f 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -4,6 +4,7 @@ import sqlalchemy import ormar.queryset # noqa I100 +from ormar.fields.many_to_many import ManyToManyField from ormar.models import NewBaseModel # noqa I100 @@ -40,10 +41,19 @@ def from_row( if select_related: related_models = group_related_list(select_related) + # breakpoint() + if ( + previous_table + and previous_table in cls.Meta.model_fields + and issubclass(cls.Meta.model_fields[previous_table], ManyToManyField) + ): + previous_table = cls.Meta.model_fields[ + previous_table + ].through.Meta.tablename + table_prefix = cls.Meta.alias_manager.resolve_relation_join( previous_table, cls.Meta.table.name ) - previous_table = cls.Meta.table.name item = cls.populate_nested_models_from_row( diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index af5b2955d..d30901c26 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -23,7 +23,8 @@ from ormar.fields.foreign_key import ForeignKeyField from ormar.models.metaclass import ModelMeta, ModelMetaclass from ormar.models.modelproxy import ModelTableProxy -from ormar.relations import AliasManager, RelationsManager +from ormar.relations.alias_manager import AliasManager +from ormar.relations.relation import RelationsManager if TYPE_CHECKING: # pragma no cover from ormar.models.model import Model @@ -96,14 +97,17 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: kwargs.get(related), self, to_register=True ) - def __setattr__(self, name: str, value: Any) -> None: + def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001 if name in self.__slots__: object.__setattr__(self, name, value) elif name == "pk": object.__setattr__(self, self.Meta.pkname, value) elif name in self._orm: model = self.Meta.model_fields[name].expand_relationship(value, self) - self.__dict__[name] = model + if isinstance(self.__dict__.get(name), list): + self.__dict__[name].append(model) + else: + self.__dict__[name] = model else: value = ( self._convert_json(name, value, "dumps") @@ -115,11 +119,11 @@ def __setattr__(self, name: str, value: Any) -> None: def __getattribute__(self, item: str) -> Any: if item in ("_orm_id", "_orm_saved", "_orm", "__fields__"): return object.__getattribute__(self, item) - elif item != "_extract_related_names" and item in self._extract_related_names(): + if item != "_extract_related_names" and item in self._extract_related_names(): return self._extract_related_model_instead_of_field(item) - elif item == "pk": + if item == "pk": return self.__dict__.get(self.Meta.pkname, None) - elif item != "__fields__" and item in self.__fields__: + if item != "__fields__" and item in self.__fields__: value = self.__dict__.get(item, None) value = self._convert_json(item, value, "loads") return value @@ -131,15 +135,20 @@ def _extract_related_model_instead_of_field( if item in self._orm: return self._orm.get(item) + def __eq__(self, other: "Model") -> bool: + if isinstance(other, NewBaseModel): + return self.__same__(other) + return super().__eq__(other) # pragma no cover + def __same__(self, other: "Model") -> bool: return ( self._orm_id == other._orm_id - or self.__dict__ == other.__dict__ + or self.dict() == other.dict() or (self.pk == other.pk and self.pk is not None) ) @classmethod - def get_name(cls, title: bool = False, lower: bool = True) -> str: + def get_name(cls, lower: bool = True) -> str: name = cls.__name__ if lower: name = name.lower() diff --git a/ormar/queryset/clause.py b/ormar/queryset/clause.py index f7436ac8b..336db05fc 100644 --- a/ormar/queryset/clause.py +++ b/ormar/queryset/clause.py @@ -5,6 +5,7 @@ import ormar # noqa I100 from ormar.exceptions import QueryDefinitionError +from ormar.fields.many_to_many import ManyToManyField if TYPE_CHECKING: # pragma no cover from ormar import Model @@ -128,6 +129,10 @@ def _determine_filter_target_table( # against which the comparison is being made. previous_table = model_cls.Meta.tablename for part in related_parts: + if issubclass(model_cls.Meta.model_fields[part], ManyToManyField): + previous_table = model_cls.Meta.model_fields[ + part + ].through.Meta.tablename current_table = model_cls.Meta.model_fields[part].to.Meta.tablename manager = model_cls.Meta.alias_manager table_prefix = manager.resolve_relation_join(previous_table, current_table) diff --git a/ormar/queryset/query.py b/ormar/queryset/query.py index e216f07ca..6f450322b 100644 --- a/ormar/queryset/query.py +++ b/ormar/queryset/query.py @@ -4,8 +4,10 @@ from sqlalchemy import text import ormar # noqa I100 +from ormar.fields import BaseField from ormar.fields.foreign_key import ForeignKeyField -from ormar.relations import AliasManager +from ormar.fields.many_to_many import ManyToManyField +from ormar.relations.alias_manager import AliasManager if TYPE_CHECKING: # pragma no cover from ormar import Model @@ -63,6 +65,15 @@ def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]: ) for part in item.split("__"): + if issubclass( + join_parameters.model_cls.Meta.model_fields[part], ManyToManyField + ): + _fields = join_parameters.model_cls.Meta.model_fields + new_part = _fields[part].to.get_name() + join_parameters = self._build_join_parameters( + part, join_parameters, is_multi=True + ) + part = new_part join_parameters = self._build_join_parameters(part, join_parameters) expr = sqlalchemy.sql.select(self.columns) @@ -83,23 +94,30 @@ def on_clause( right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}" return text(f"{left_part}={right_part}") + def _is_target_relation_key( + self, field: BaseField, target_model: Type["Model"] + ) -> bool: + return issubclass(field, ForeignKeyField) and field.to.Meta == target_model.Meta + def _build_join_parameters( - self, part: str, join_params: JoinParameters + self, part: str, join_params: JoinParameters, is_multi: bool = False ) -> JoinParameters: - model_cls = join_params.model_cls.Meta.model_fields[part].to + if is_multi: + model_cls = join_params.model_cls.Meta.model_fields[part].through + else: + model_cls = join_params.model_cls.Meta.model_fields[part].to to_table = model_cls.Meta.table.name alias = model_cls.Meta.alias_manager.resolve_relation_join( join_params.from_table, to_table ) if alias not in self.used_aliases: - if join_params.prev_model.Meta.model_fields[part].virtual: + if join_params.prev_model.Meta.model_fields[part].virtual or is_multi: to_key = next( ( v for k, v in model_cls.Meta.model_fields.items() - if issubclass(v, ForeignKeyField) - and v.to == join_params.prev_model + if self._is_target_relation_key(v, join_params.prev_model) ), None, ).name @@ -129,16 +147,19 @@ def _build_join_parameters( prev_model = model_cls return JoinParameters(prev_model, previous_alias, from_table, model_cls) - def _apply_expression_modifiers( - self, expr: sqlalchemy.sql.select - ) -> sqlalchemy.sql.select: + def filter(self, expr: sqlalchemy.sql.select) -> sqlalchemy.sql.select: # noqa A003 if self.filter_clauses: if len(self.filter_clauses) == 1: clause = self.filter_clauses[0] else: clause = sqlalchemy.sql.and_(*self.filter_clauses) expr = expr.where(clause) + return expr + def _apply_expression_modifiers( + self, expr: sqlalchemy.sql.select + ) -> sqlalchemy.sql.select: + expr = self.filter(expr) if self.limit_count: expr = expr.limit(self.limit_count) diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 886c29c21..4bc5e7f25 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -48,6 +48,7 @@ def build_select_expression(self) -> sqlalchemy.sql.select: limit_count=self.limit_count, ) exp = qry.build_select_expression() + # print(exp.compile(compile_kwargs={"literal_binds": True})) return exp def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003 @@ -70,7 +71,7 @@ def select_related(self, related: Union[List, Tuple, str]) -> "QuerySet": if not isinstance(related, (list, tuple)): related = [related] - related = list(self._select_related) + related + related = list(set(list(self._select_related) + related)) return self.__class__( model_cls=self.model_cls, filter_clauses=self.filter_clauses, @@ -82,13 +83,28 @@ def select_related(self, related: Union[List, Tuple, str]) -> "QuerySet": async def exists(self) -> bool: expr = self.build_select_expression() expr = sqlalchemy.exists(expr).select() + # print(expr.compile(compile_kwargs={"literal_binds": True})) return await self.database.fetch_val(expr) async def count(self) -> int: expr = self.build_select_expression().alias("subquery_for_count") expr = sqlalchemy.func.count().select().select_from(expr) + # print(expr.compile(compile_kwargs={"literal_binds": True})) return await self.database.fetch_val(expr) + async def delete(self, **kwargs: Any) -> int: + if kwargs: + return await self.filter(**kwargs).delete() + qry = Query( + model_cls=self.model_cls, + select_related=self._select_related, + filter_clauses=self.filter_clauses, + offset=self.query_offset, + limit_count=self.limit_count, + ) + expr = qry.filter(self.table.delete()) + return await self.database.execute(expr) + def limit(self, limit_count: int) -> "QuerySet": return self.__class__( model_cls=self.model_cls, @@ -118,11 +134,11 @@ async def first(self, **kwargs: Any) -> "Model": async def get(self, **kwargs: Any) -> "Model": if kwargs: return await self.filter(**kwargs).get() + + if not self.filter_clauses: + expr = self.build_select_expression().limit(2) else: - if not self.filter_clauses: - expr = self.build_select_expression().limit(2) - else: - expr = self.build_select_expression() + expr = self.build_select_expression() rows = await self.database.fetch_all(expr) @@ -143,6 +159,7 @@ async def all(self, **kwargs: Any) -> List["Model"]: # noqa: A003 return await self.filter(**kwargs).all() expr = self.build_select_expression() + # breakpoint() rows = await self.database.fetch_all(expr) result_rows = [ self.model_cls.from_row(row, select_related=self._select_related) diff --git a/ormar/relations.py b/ormar/relations.py deleted file mode 100644 index d8c597e1e..000000000 --- a/ormar/relations.py +++ /dev/null @@ -1,198 +0,0 @@ -import string -import uuid -from enum import Enum -from random import choices -from typing import List, Optional, TYPE_CHECKING, Type, Union -from weakref import proxy - -import sqlalchemy -from sqlalchemy import text - -import ormar # noqa I100 -from ormar.exceptions import RelationshipInstanceError # noqa I100 -from ormar.fields.foreign_key import ForeignKeyField # noqa I100 - - -if TYPE_CHECKING: # pragma no cover - from ormar.models import Model - - -def get_table_alias() -> str: - return "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4] - - -class RelationType(Enum): - PRIMARY = 1 - REVERSE = 2 - - -class AliasManager: - def __init__(self) -> None: - self._aliases = dict() - - @staticmethod - def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]: - return [ - text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}") - for column in table.columns - ] - - @staticmethod - def prefixed_table_name(alias: str, name: str) -> text: - return text(f"{name} {alias}_{name}") - - def add_relation_type(self, field: ForeignKeyField, table_name: str,) -> None: - if f"{table_name}_{field.to.Meta.tablename}" not in self._aliases: - self._aliases[f"{table_name}_{field.to.Meta.tablename}"] = get_table_alias() - if f"{field.to.Meta.tablename}_{table_name}" not in self._aliases: - self._aliases[f"{field.to.Meta.tablename}_{table_name}"] = get_table_alias() - - def resolve_relation_join(self, from_table: str, to_table: str) -> str: - return self._aliases.get(f"{from_table}_{to_table}", "") - - -class RelationProxy(list): - def __init__(self, relation: "Relation") -> None: - super(RelationProxy, self).__init__() - self.relation = relation - self._owner = self.relation.manager.owner - - def remove(self, item: "Model") -> None: - super().remove(item) - rel_name = item.resolve_relation_name(item, self._owner) - item._orm._get(rel_name).remove(self._owner) - - def append(self, item: "Model") -> None: - super().append(item) - - def add(self, item: "Model") -> None: - rel_name = item.resolve_relation_name(item, self._owner) - setattr(item, rel_name, self._owner) - - -class Relation: - def __init__(self, manager: "RelationsManager", type_: RelationType) -> None: - self.manager = manager - self._owner = manager.owner - self._type = type_ - self.related_models = ( - RelationProxy(relation=self) if type_ == RelationType.REVERSE else None - ) - - def _find_existing(self, child: "Model") -> Optional[int]: - for ind, relation_child in enumerate(self.related_models[:]): - try: - if relation_child.__same__(child): - return ind - except ReferenceError: # pragma no cover - self.related_models.pop(ind) - return None - - def add(self, child: "Model") -> None: - relation_name = self._owner.resolve_relation_name(self._owner, child) - if self._type == RelationType.PRIMARY: - self.related_models = child - self._owner.__dict__[relation_name] = child - else: - if self._find_existing(child) is None: - self.related_models.append(child) - rel = self._owner.__dict__.get(relation_name, []) - rel.append(child) - self._owner.__dict__[relation_name] = rel - - def remove(self, child: "Model") -> None: - relation_name = self._owner.resolve_relation_name(self._owner, child) - if self._type == RelationType.PRIMARY: - if self.related_models.__same__(child): - self.related_models = None - del self._owner.__dict__[relation_name] - else: - position = self._find_existing(child) - if position is not None: - self.related_models.pop(position) - del self._owner.__dict__[relation_name][position] - - def get(self) -> Union[List["Model"], "Model"]: - return self.related_models - - def __repr__(self) -> str: # pragma no cover - return str(self.related_models) - - -class RelationsManager: - def __init__( - self, related_fields: List[Type[ForeignKeyField]] = None, owner: "Model" = None - ) -> None: - self.owner = owner - self._related_fields = related_fields or [] - self._related_names = [field.name for field in self._related_fields] - self._relations = dict() - for field in self._related_fields: - self._add_relation(field) - - def _add_relation(self, field: Type[ForeignKeyField]) -> None: - self._relations[field.name] = Relation( - manager=self, - type_=RelationType.PRIMARY if not field.virtual else RelationType.REVERSE, - ) - - def __contains__(self, item: str) -> bool: - return item in self._related_names - - def get(self, name: str) -> Optional[Union[List["Model"], "Model"]]: - relation = self._relations.get(name, None) - if relation: - return relation.get() - - def _get(self, name: str) -> Optional[Relation]: - relation = self._relations.get(name, None) - if relation: - return relation - - @staticmethod - def add(parent: "Model", child: "Model", child_name: str, virtual: bool) -> None: - to_field = next( - ( - field - for field in child._orm._related_fields - if field.to == parent.__class__ or field.to.Meta == parent.Meta - ), - None, - ) - - if not to_field: # pragma no cover - raise RelationshipInstanceError( - f"Model {child.__class__} does not have " - f"reference to model {parent.__class__}" - ) - - to_name = to_field.name - if virtual: - child_name, to_name = to_name, child_name or child.get_name() - child, parent = parent, proxy(child) - else: - child_name = child_name or child.get_name() + "s" - child = proxy(child) - - parent_relation = parent._orm._get(child_name) - if not parent_relation: - ormar.models.expand_reverse_relationships(child.__class__) - name = parent.resolve_relation_name(parent, child) - field = parent.Meta.model_fields[name] - parent._orm._add_relation(field) - parent_relation = parent._orm._get(child_name) - parent_relation.add(child) - child._orm._get(to_name).add(parent) - - def remove(self, name: str, child: "Model") -> None: - relation = self._get(name) - relation.remove(child) - - @staticmethod - def remove_parent(item: "Model", name: Union[str, "Model"]) -> None: - related_model = name - name = item.resolve_relation_name(item, related_model) - if name in item._orm: - relation_name = item.resolve_relation_name(related_model, item) - item._orm.remove(name, related_model) - related_model._orm.remove(relation_name, item) diff --git a/ormar/relations/__init__.py b/ormar/relations/__init__.py new file mode 100644 index 000000000..04c5468b2 --- /dev/null +++ b/ormar/relations/__init__.py @@ -0,0 +1,3 @@ +from ormar.relations.alias_manager import AliasManager + +__all__ = ["AliasManager"] diff --git a/ormar/relations/alias_manager.py b/ormar/relations/alias_manager.py new file mode 100644 index 000000000..1af8c2204 --- /dev/null +++ b/ormar/relations/alias_manager.py @@ -0,0 +1,36 @@ +import string +import uuid +from random import choices +from typing import List + +import sqlalchemy +from sqlalchemy import text + + +def get_table_alias() -> str: + return "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4] + + +class AliasManager: + def __init__(self) -> None: + self._aliases = dict() + + @staticmethod + def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]: + return [ + text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}") + for column in table.columns + ] + + @staticmethod + def prefixed_table_name(alias: str, name: str) -> text: + return text(f"{name} {alias}_{name}") + + def add_relation_type(self, to_table_name: str, table_name: str,) -> None: + if f"{table_name}_{to_table_name}" not in self._aliases: + self._aliases[f"{table_name}_{to_table_name}"] = get_table_alias() + if f"{to_table_name}_{table_name}" not in self._aliases: + self._aliases[f"{to_table_name}_{table_name}"] = get_table_alias() + + def resolve_relation_join(self, from_table: str, to_table: str) -> str: + return self._aliases.get(f"{from_table}_{to_table}", "") diff --git a/ormar/relations/relation.py b/ormar/relations/relation.py new file mode 100644 index 000000000..d8599a804 --- /dev/null +++ b/ormar/relations/relation.py @@ -0,0 +1,332 @@ +from enum import Enum +from typing import Any, List, Optional, TYPE_CHECKING, Tuple, Type, Union +from weakref import proxy + +import ormar # noqa I100 +from ormar.exceptions import RelationshipInstanceError # noqa I100 +from ormar.fields.foreign_key import ForeignKeyField # noqa I100 +from ormar.fields.many_to_many import ManyToManyField +from ormar.queryset import QuerySet + +if TYPE_CHECKING: # pragma no cover + from ormar import Model + + +class RelationType(Enum): + PRIMARY = 1 + REVERSE = 2 + MULTIPLE = 3 + + +class QuerysetProxy: + if TYPE_CHECKING: # pragma no cover + relation: "Relation" + + def __init__(self, relation: "Relation") -> None: + self.relation = relation + self.queryset = None + + def _assign_child_to_parent(self, child: "Model") -> None: + owner = self.relation._owner + rel_name = owner.resolve_relation_name(owner, child) + setattr(owner, rel_name, child) + + def _register_related(self, child: Union["Model", List["Model"]]) -> None: + if isinstance(child, list): + for subchild in child: + self._assign_child_to_parent(subchild) + else: + self._assign_child_to_parent(child) + + async def create_through_instance(self, child: "Model") -> None: + queryset = QuerySet(model_cls=self.relation.through) + owner_column = self.relation._owner.get_name() + child_column = child.get_name() + kwargs = {owner_column: self.relation._owner, child_column: child} + await queryset.create(**kwargs) + + async def delete_through_instance(self, child: "Model") -> None: + queryset = QuerySet(model_cls=self.relation.through) + owner_column = self.relation._owner.get_name() + child_column = child.get_name() + kwargs = {owner_column: self.relation._owner, child_column: child} + link_instance = await queryset.filter(**kwargs).get() + await link_instance.delete() + + def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003 + return self.queryset.filter(**kwargs) + + def select_related(self, related: Union[List, Tuple, str]) -> "QuerySet": + return self.queryset.select_related(related) + + async def exists(self) -> bool: + return await self.queryset.exists() + + async def count(self) -> int: + return await self.queryset.count() + + async def clear(self) -> int: + queryset = QuerySet(model_cls=self.relation.through) + owner_column = self.relation._owner.get_name() + kwargs = {owner_column: self.relation._owner} + return await queryset.delete(**kwargs) + + def limit(self, limit_count: int) -> "QuerySet": + return self.queryset.limit(limit_count) + + def offset(self, offset: int) -> "QuerySet": + return self.queryset.offset(offset) + + async def first(self, **kwargs: Any) -> "Model": + first = await self.queryset.first(**kwargs) + self._register_related(first) + return first + + async def get(self, **kwargs: Any) -> "Model": + get = await self.queryset.get(**kwargs) + self._register_related(get) + return get + + async def all(self, **kwargs: Any) -> List["Model"]: # noqa: A003 + all_items = await self.queryset.all(**kwargs) + self._register_related(all_items) + return all_items + + async def create(self, **kwargs: Any) -> "Model": + create = await self.queryset.create(**kwargs) + self._register_related(create) + await self.create_through_instance(create) + return create + + +class RelationProxy(list): + def __init__(self, relation: "Relation") -> None: + super(RelationProxy, self).__init__() + self.relation = relation + self._owner = self.relation.manager.owner + self.queryset_proxy = QuerysetProxy(relation=self.relation) + + def __getattribute__(self, item: str) -> Any: + if item in ["count", "clear"]: + if not self.queryset_proxy.queryset: + self.queryset_proxy.queryset = self._set_queryset() + return getattr(self.queryset_proxy, item) + return super().__getattribute__(item) + + def __getattr__(self, item: str) -> Any: + if not self.queryset_proxy.queryset: + self.queryset_proxy.queryset = self._set_queryset() + return getattr(self.queryset_proxy, item) + + def _set_queryset(self) -> QuerySet: + owner_table = self.relation._owner.Meta.tablename + pkname = self.relation._owner.Meta.pkname + pk_value = self.relation._owner.pk + if not pk_value: + raise RelationshipInstanceError( + "You cannot query many to many relationship on unsaved model." + ) + kwargs = {f"{owner_table}__{pkname}": pk_value} + queryset = ( + QuerySet(model_cls=self.relation.to) + .select_related(owner_table) + .filter(**kwargs) + ) + return queryset + + async def remove(self, item: "Model") -> None: + super().remove(item) + rel_name = item.resolve_relation_name(item, self._owner) + item._orm._get(rel_name).remove(self._owner) + if self.relation._type == RelationType.MULTIPLE: + await self.queryset_proxy.delete_through_instance(item) + + def append(self, item: "Model") -> None: + super().append(item) + + async def add(self, item: "Model") -> None: + if self.relation._type == RelationType.MULTIPLE: + await self.queryset_proxy.create_through_instance(item) + rel_name = item.resolve_relation_name(item, self._owner) + setattr(item, rel_name, self._owner) + + +class Relation: + def __init__( + self, + manager: "RelationsManager", + type_: RelationType, + to: Type["Model"], + through: Type["Model"] = None, + ) -> None: + self.manager = manager + self._owner = manager.owner + self._type = type_ + self.to = to + self.through = through + self.related_models = ( + RelationProxy(relation=self) + if type_ in (RelationType.REVERSE, RelationType.MULTIPLE) + else None + ) + + def _find_existing(self, child: "Model") -> Optional[int]: + for ind, relation_child in enumerate(self.related_models[:]): + try: + if relation_child.__same__(child): + return ind + except ReferenceError: # pragma no cover + self.related_models.pop(ind) + return None + + def add(self, child: "Model") -> None: + relation_name = self._owner.resolve_relation_name(self._owner, child) + if self._type == RelationType.PRIMARY: + self.related_models = child + self._owner.__dict__[relation_name] = child + else: + if self._find_existing(child) is None: + self.related_models.append(child) + rel = self._owner.__dict__.get(relation_name, []) + rel = rel or [] + if not isinstance(rel, list): + rel = [rel] + rel.append(child) + self._owner.__dict__[relation_name] = rel + + def remove(self, child: "Model") -> None: + relation_name = self._owner.resolve_relation_name(self._owner, child) + if self._type == RelationType.PRIMARY: + if self.related_models.__same__(child): + self.related_models = None + del self._owner.__dict__[relation_name] + else: + position = self._find_existing(child) + if position is not None: + self.related_models.pop(position) + del self._owner.__dict__[relation_name][position] + + def get(self) -> Union[List["Model"], "Model"]: + return self.related_models + + def __repr__(self) -> str: # pragma no cover + return str(self.related_models) + + +class RelationsManager: + def __init__( + self, related_fields: List[Type[ForeignKeyField]] = None, owner: "Model" = None + ) -> None: + self.owner = proxy(owner) + self._related_fields = related_fields or [] + self._related_names = [field.name for field in self._related_fields] + self._relations = dict() + for field in self._related_fields: + self._add_relation(field) + + def _get_relation_type(self, field: Type[ForeignKeyField]) -> RelationType: + if issubclass(field, ManyToManyField): + return RelationType.MULTIPLE + return RelationType.PRIMARY if not field.virtual else RelationType.REVERSE + + def _add_relation(self, field: Type[ForeignKeyField]) -> None: + self._relations[field.name] = Relation( + manager=self, + type_=self._get_relation_type(field), + to=field.to, + through=getattr(field, "through", None), + ) + + def __contains__(self, item: str) -> bool: + return item in self._related_names + + def get(self, name: str) -> Optional[Union[List["Model"], "Model"]]: + relation = self._relations.get(name, None) + if relation is not None: + return relation.get() + + def _get(self, name: str) -> Optional[Relation]: + relation = self._relations.get(name, None) + if relation is not None: + return relation + + @staticmethod + def register_missing_relation( + parent: "Model", child: "Model", child_name: str + ) -> Relation: + ormar.models.expand_reverse_relationships(child.__class__) + name = parent.resolve_relation_name(parent, child) + field = parent.Meta.model_fields[name] + parent._orm._add_relation(field) + parent_relation = parent._orm._get(child_name) + return parent_relation + + @staticmethod + def get_relations_sides_and_names( + to_field: Type[ForeignKeyField], + parent: "Model", + child: "Model", + child_name: str, + virtual: bool, + ) -> Tuple["Model", "Model", str, str]: + to_name = to_field.name + if issubclass(to_field, ManyToManyField): + child_name, to_name = ( + child.resolve_relation_name(parent, child), + child.resolve_relation_name(child, parent), + ) + child = proxy(child) + elif virtual: + child_name, to_name = to_name, child_name or child.get_name() + child, parent = parent, proxy(child) + else: + child_name = child_name or child.get_name() + "s" + child = proxy(child) + return parent, child, child_name, to_name + + @staticmethod + def add(parent: "Model", child: "Model", child_name: str, virtual: bool) -> None: + to_field = next( + ( + field + for field in child._orm._related_fields + if field.to == parent.__class__ or field.to.Meta == parent.Meta + ), + None, + ) + + if not to_field: # pragma no cover + raise RelationshipInstanceError( + f"Model {child.__class__} does not have " + f"reference to model {parent.__class__}" + ) + + ( + parent, + child, + child_name, + to_name, + ) = RelationsManager.get_relations_sides_and_names( + to_field, parent, child, child_name, virtual + ) + + parent_relation = parent._orm._get(child_name) + if not parent_relation: + parent_relation = RelationsManager.register_missing_relation( + parent, child, child_name + ) + parent_relation.add(child) + child._orm._get(to_name).add(parent) + + def remove(self, name: str, child: "Model") -> None: + relation = self._get(name) + relation.remove(child) + + @staticmethod + def remove_parent(item: "Model", name: Union[str, "Model"]) -> None: + related_model = name + name = item.resolve_relation_name(item, related_model) + if name in item._orm: + relation_name = item.resolve_relation_name(related_model, item) + item._orm.remove(name, related_model) + related_model._orm.remove(relation_name, item) diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index fb85fc7d5..98a7d5b15 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -1,5 +1,3 @@ -import gc - import databases import pytest import sqlalchemy @@ -179,7 +177,7 @@ async def test_model_removal_from_relations(): await track3.save() assert len(album.tracks) == 3 - album.tracks.remove(track1) + await album.tracks.remove(track1) assert len(album.tracks) == 2 assert track1.album is None @@ -187,7 +185,7 @@ async def test_model_removal_from_relations(): track1 = await Track.objects.get(title="The Birdman") assert track1.album is None - album.tracks.add(track1) + await album.tracks.add(track1) assert len(album.tracks) == 3 assert track1.album == album diff --git a/tests/test_many_to_many.py b/tests/test_many_to_many.py new file mode 100644 index 000000000..a2963e805 --- /dev/null +++ b/tests/test_many_to_many.py @@ -0,0 +1,178 @@ +import databases +import pytest +import sqlalchemy + +import ormar +from ormar.exceptions import RelationshipInstanceError +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class Author(ormar.Model): + class Meta: + tablename = "authors" + database = database + metadata = metadata + + id: ormar.Integer(primary_key=True) + first_name: ormar.String(max_length=80) + last_name: ormar.String(max_length=80) + + +class Category(ormar.Model): + class Meta: + tablename = "categories" + database = database + metadata = metadata + + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=40) + + +class PostCategory(ormar.Model): + class Meta: + tablename = "posts_categories" + database = database + metadata = metadata + + +class Post(ormar.Model): + class Meta: + tablename = "posts" + database = database + metadata = metadata + + id: ormar.Integer(primary_key=True) + title: ormar.String(max_length=200) + categories: ormar.ManyToMany(Category, through=PostCategory) + author: ormar.ForeignKey(Author) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.fixture(scope="function") +async def cleanup(): + yield + await PostCategory.objects.delete() + await Post.objects.delete() + await Category.objects.delete() + await Author.objects.delete() + + +@pytest.mark.asyncio +async def test_assigning_related_objects(cleanup): + guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") + post = await Post.objects.create(title="Hello, M2M", author=guido) + news = await Category.objects.create(name="News") + + # Add a category to a post. + await post.categories.add(news) + # or from the other end: + await news.posts.add(post) + + # Creating related object from instance: + await post.categories.create(name="Tips") + assert len(post.categories) == 2 + + post_categories = await post.categories.all() + assert len(post_categories) == 2 + + +@pytest.mark.asyncio +async def test_quering_of_the_m2m_models(cleanup): + # orm can do this already. + guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") + post = await Post.objects.create(title="Hello, M2M", author=guido) + news = await Category.objects.create(name="News") + # tl;dr: `post.categories` exposes the QuerySet API. + + await post.categories.add(news) + + post_categories = await post.categories.all() + assert len(post_categories) == 1 + + assert news == await post.categories.get(name="News") + + num_posts = await news.posts.count() + assert num_posts == 1 + + posts_about_m2m = await news.posts.filter(title__contains="M2M").all() + assert len(posts_about_m2m) == 1 + assert posts_about_m2m[0] == post + posts_about_python = await Post.objects.filter(categories__name="python").all() + assert len(posts_about_python) == 0 + + # Traversal of relationships: which categories has Guido contributed to? + category = await Category.objects.filter(posts__author=guido).get() + assert category == news + # or: + category2 = await Category.objects.filter(posts__author__first_name="Guido").get() + assert category2 == news + + +@pytest.mark.asyncio +async def test_removal_of_the_relations(cleanup): + guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") + post = await Post.objects.create(title="Hello, M2M", author=guido) + news = await Category.objects.create(name="News") + await post.categories.add(news) + assert len(await post.categories.all()) == 1 + await post.categories.remove(news) + assert len(await post.categories.all()) == 0 + # or: + await news.posts.add(post) + assert len(await news.posts.all()) == 1 + await news.posts.remove(post) + assert len(await news.posts.all()) == 0 + + # Remove all related objects: + await post.categories.add(news) + await post.categories.clear() + assert len(await post.categories.all()) == 0 + + # post would also lose 'news' category when running: + await post.categories.add(news) + await news.delete() + assert len(await post.categories.all()) == 0 + + +@pytest.mark.asyncio +async def test_selecting_related(cleanup): + guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") + post = await Post.objects.create(title="Hello, M2M", author=guido) + news = await Category.objects.create(name="News") + recent = await Category.objects.create(name="Recent") + await post.categories.add(news) + await post.categories.add(recent) + assert len(await post.categories.all()) == 2 + # Loads categories and posts (2 queries) and perform the join in Python. + categories = await Category.objects.select_related("posts").all() + # No extra queries needed => no more `await`s required. + for category in categories: + assert category.posts[0] == post + + news_posts = await news.posts.select_related("author").all() + assert news_posts[0].author == guido + + assert (await post.categories.limit(1).all())[0] == news + assert (await post.categories.offset(1).limit(1).all())[0] == recent + + assert await post.categories.first() == news + + assert await post.categories.exists() + + +@pytest.mark.asyncio +async def test_selecting_related_fail_without_saving(cleanup): + guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") + post = Post(title="Hello, M2M", author=guido) + with pytest.raises(RelationshipInstanceError): + await post.categories.all()