Skip to content

Commit

Permalink
--wip-- [skip ci]
Browse files Browse the repository at this point in the history
  • Loading branch information
art049 committed Oct 9, 2022
1 parent fe746d0 commit 97b3852
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 15 deletions.
22 changes: 21 additions & 1 deletion odmantic/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,17 @@
from pymongo.command_cursor import CommandCursor
from pymongo.database import Database

from odmantic.exceptions import DocumentNotFoundError, DuplicateKeyError
from odmantic.exceptions import (
DocumentNotFoundError,
DuplicateKeyError,
ReferencedDocumentNotFoundError,
ReferenceNotFoundError,
)
from odmantic.field import FieldProxy, ODMReference
from odmantic.index import ODMBaseIndex
from odmantic.model import Model
from odmantic.query import QueryExpression, SortExpression, and_
from odmantic.reference import Reference
from odmantic.session import (
AIOSession,
AIOSessionBase,
Expand Down Expand Up @@ -511,6 +517,20 @@ async def find_one(
return None
return results[0]

async def resolve(
self,
ref: Reference[ModelType],
session: AIOSessionType = None,
) -> ModelType:
result = await self.find_one(
ref.model,
getattr(ref.model, ref.model.__primary_field__) == ref.pointer,
session=session,
)
if result is None:
raise ReferenceNotFoundError(ref.model, ref.pointer)
return result

async def _save(
self, instance: ModelType, session: "AsyncIOMotorClientSession"
) -> ModelType:
Expand Down
18 changes: 18 additions & 0 deletions odmantic/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,24 @@ def __init__(
)


class ReferenceNotFoundError(BaseEngineException):
"""The referenced document has not been found by the engine.
Attributes:
model: the referenced model class that has not been found
pointer: the pointer to the referenced document (value of the primary field)
"""

def __init__(self, model: Type["Model"], pointer: Any):
self.model = model
self.pointer = pointer
super().__init__(
f"Document not found for : {model.__name__}. "
f"Reference details: {model.__primary_field__} -> {pointer}",
model,
)


ErrorList = List[Union[Sequence[Any], ErrorWrapper]]


Expand Down
59 changes: 48 additions & 11 deletions odmantic/reference.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
from typing import Any, Optional
from enum import Enum
from typing import Any, Generic, Optional, Type, TypeVar

from pydantic.fields import Undefined

def Reference(*, key_name: Optional[str] = None) -> Any:
"""Used to define reference fields.
from odmantic.typing import Annotated

Args:
key_name: name of the Mongo key that stores the foreign key
<!--
#noqa: DAR201
-->
"""
return ODMReferenceInfo(key_name=key_name)
ModelType = TypeVar("ModelType")


class ODMReferenceInfo:
Expand All @@ -21,3 +15,46 @@ class ODMReferenceInfo:

def __init__(self, key_name: Optional[str]):
self.key_name = key_name


class ReferenceMode(Enum):
EAGER = "EAGER"
LAZY = "LAZY"
QUERY = "QUERY"


EagerReference = Annotated[ModelType, ReferenceMode.EAGER]
LazyReference = Annotated[ModelType, ReferenceMode.LAZY]
Reference = Annotated[ModelType, ReferenceMode.QUERY]


class ReferenceProxy(
Generic[ModelType],
):
__instance__: Optional[ModelType] = None
__pointer__: Any
"""Used to define reference fields.
Args:
key_name: optional name of the Mongo key that stores the foreign key
<!--
#noqa: DAR201
-->
"""

def __init__(self, *, key_name: Optional[str] = None) -> None:
self.__reference_key_name__ = key_name

def __resolve__(self, instance: ModelType) -> None:
self.__instance__ = instance

def __getattribute__(self, __name: str) -> Any:
if self.__instance__ is None:
raise AttributeError(
"Cannot access attribute of a LazyReference before instance resolution"
)
instance_attr = getattr(self.instance, __name, Undefined)
if instance_attr is not Undefined:
return instance_attr
return super().__getattribute__(__name)
4 changes: 2 additions & 2 deletions odmantic/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

# Handles globally the typing imports from typing or the typing_extensions backport
if sys.version_info < (3, 8):
from typing_extensions import Literal, get_args, get_origin
from typing_extensions import Annotated, Literal, get_args, get_origin
else:
from typing import Literal, get_args, get_origin # noqa: F401
from typing import Annotated, Literal, get_args, get_origin # noqa: F401

if sys.version_info < (3, 11):
from typing_extensions import dataclass_transform
Expand Down
13 changes: 12 additions & 1 deletion tests/unit/test_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from odmantic.field import Field
from odmantic.model import Model
from odmantic.reference import Reference
from odmantic.reference import EagerReference, LazyReference, Reference


def test_build_query_filter_across_reference():
Expand Down Expand Up @@ -39,3 +39,14 @@ class M(Model):
r = Referenced(key=1)
m = M(ref=r)
assert m.doc()["ref"] == 1


def test_eager_ref():
class Referenced(Model):
f: int

class M(Model):
ref: EagerReference[Referenced]
ref2: LazyReference[Referenced]

M.ref2.f

0 comments on commit 97b3852

Please sign in to comment.