Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the date operator to query on date-time field. #67

Merged
merged 1 commit into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/en/docs/queries.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,9 @@ The same special operators are also automatically added on every column.
* **neq** - Filter instances by not equal to condition.
* **startswith** - Filter instances that start with a specific value.
* **endswith** - Filter instances that end with a specific value.
* **istartswith** - Filter instances that start with a specific value, case-insensitive.
* **istartswith** - Filter instances that start with a specific value, case-insensitive.
* **iendswith** - Filter instances that end with a specific value, case-insensitive.
* **date** - Filter instances by date.

##### Example

Expand All @@ -153,6 +154,7 @@ users = await User.objects.filter(name__startswith="foo")
users = await User.objects.filter(name__istartswith="foo")
users = await User.objects.filter(name__endswith="foo")
users = await User.objects.filter(name__iendswith="foo")
users = await User.objects.filter(updated_at__date=date.today())
```

### Using
Expand Down
3 changes: 2 additions & 1 deletion mongoz/conf/global_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ class MongozSettings(Settings):
"startswith": "startswith",
"istartswith": "istartswith",
"endswith": "endswith",
"iendswith": "iendswith"
"iendswith": "iendswith",
"date": "date",
}

def get_operator(self, name: str) -> "Expression":
Expand Down
98 changes: 78 additions & 20 deletions mongoz/core/db/querysets/core/manager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import datetime, timedelta
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -26,9 +27,16 @@
ORDER_EQUALITY,
VALUE_EQUALITY,
)
from mongoz.core.db.querysets.core.protocols import AwaitableQuery, MongozDocument
from mongoz.core.db.querysets.core.protocols import (
AwaitableQuery,
MongozDocument,
)
from mongoz.core.db.querysets.expressions import Expression, SortExpression
from mongoz.exceptions import DocumentNotFound, FieldDefinitionError, MultipleDocumentsReturned
from mongoz.exceptions import (
DocumentNotFound,
FieldDefinitionError,
MultipleDocumentsReturned,
)
from mongoz.protocols.queryset import QuerySetProtocol
from mongoz.utils.enums import OrderEnum

Expand Down Expand Up @@ -88,8 +96,12 @@ class registry using the database_name that provided in \
- return the self instance.
"""
manager: "Manager" = self.clone()
database = manager.model_class.meta.registry.get_database(database_name)
manager._collection = database.get_collection(manager._collection.name)._collection
database = manager.model_class.meta.registry.get_database(
database_name
)
manager._collection = database.get_collection(
manager._collection.name
)._collection
return manager

def clone(self) -> Any:
Expand All @@ -107,7 +119,9 @@ def clone(self) -> Any:

def validate_only_and_defer(self) -> None:
if self._only_fields and self._defer_fields:
raise FieldDefinitionError("You cannot use .only() and .defer() at the same time.")
raise FieldDefinitionError(
"You cannot use .only() and .defer() at the same time."
)

def get_operator(self, name: str) -> Expression:
"""
Expand All @@ -123,7 +137,9 @@ def _find_and_replace_id(self, key: str) -> str:
return cast(str, self.model_class.id.pydantic_field.alias) # type: ignore
return key

def filter_only_and_defer(self, *fields: Sequence[str], is_only: bool = False) -> "Manager":
def filter_only_and_defer(
self, *fields: Sequence[str], is_only: bool = False
) -> "Manager":
"""
Validates if should be defer or only and checks it out
"""
Expand Down Expand Up @@ -190,9 +206,15 @@ def filter_query(self, exclude: bool = False, **kwargs: Any) -> "Manager":
and value
):
asc_or_desc = lookup_operator
elif lookup_operator == OrderEnum.ASCENDING and value is False:
elif (
lookup_operator == OrderEnum.ASCENDING
and value is False
):
asc_or_desc = OrderEnum.DESCENDING
elif lookup_operator == OrderEnum.DESCENDING and value is False:
elif (
lookup_operator == OrderEnum.DESCENDING
and value is False
):
asc_or_desc = OrderEnum.ASCENDING
else:
asc_or_desc = OrderEnum.ASCENDING
Expand All @@ -207,6 +229,17 @@ def filter_query(self, exclude: bool = False, **kwargs: Any) -> "Manager":
operator = self.get_operator(lookup_operator)
expression = operator(field_name, value) # type: ignore

# For "date"
elif lookup_operator == "date":
operator = self.get_operator("gte")
from_datetime = datetime.combine(
value, datetime.min.time()
)
expression1 = operator(field_name, from_datetime) # type: ignore
clauses.append(expression1)
operator = self.get_operator("lt")
expression = operator(field_name, from_datetime + timedelta(days=1)) # type: ignore

# Add expression to the clauses
clauses.append(expression)

Expand Down Expand Up @@ -249,7 +282,9 @@ def raw(self, *values: Union[bool, Dict, Expression]) -> "Manager":
"""
manager: "Manager" = self.clone()
for value in values:
assert isinstance(value, (dict, Expression)), "Invalid argument to Raw"
assert isinstance(
value, (dict, Expression)
), "Invalid argument to Raw"
if isinstance(value, dict):
query_expressions = Expression.unpack(value)
manager._filter.extend(query_expressions)
Expand Down Expand Up @@ -290,7 +325,10 @@ def skip(self, count: int = 0) -> "Manager[T]":
return manager

def sort(
self, key: Union[Any, None] = None, direction: Union[Order, None] = None, **kwargs: Any
self,
key: Union[Any, None] = None,
direction: Union[Order, None] = None,
**kwargs: Any,
) -> "Manager[T]":
"""Sort by (key, direction) or [(key, direction)]."""
manager: "Manager" = self.clone()
Expand Down Expand Up @@ -364,7 +402,7 @@ async def _all(self) -> List[T]:
only_fields=manager._only_fields,
is_defer_fields=is_defer_fields,
defer_fields=manager._defer_fields,
from_collection=manager._collection
from_collection=manager._collection,
)
async for document in cursor
]
Expand All @@ -378,14 +416,18 @@ async def count(self, **kwargs: Any) -> int:
manager: "Manager" = self.clone()

filter_query = Expression.compile_many(manager._filter)
return cast(int, await manager._collection.count_documents(filter_query))
return cast(
int, await manager._collection.count_documents(filter_query)
)

async def create(self, **kwargs: Any) -> "Document":
"""
Creates a mongo db document.
"""
manager: "Manager" = self.clone()
instance = await manager.model_class(**kwargs).create(manager._collection)
instance = await manager.model_class(**kwargs).create(
manager._collection
)
return cast("Document", instance)

async def delete(self) -> int:
Expand Down Expand Up @@ -448,14 +490,19 @@ async def get_or_none(self, **kwargs: Any) -> Union["T", "Document", None]:
raise MultipleDocumentsReturned()
return cast(T, objects[0])

async def get_or_create(self, defaults: Union[Dict[str, Any], None] = None) -> T:
async def get_or_create(
self, defaults: Union[Dict[str, Any], None] = None
) -> T:
manager: "Manager" = self.clone()
if not defaults:
defaults = {}

data = {expression.key: expression.value for expression in manager._filter}
data = {
expression.key: expression.value for expression in manager._filter
}
defaults = {
(key if isinstance(key, str) else key._name): value for key, value in defaults.items()
(key if isinstance(key, str) else key._name): value
for key, value in defaults.items()
}

try:
Expand Down Expand Up @@ -528,12 +575,21 @@ async def update_many(self, **kwargs: Any) -> List[T]:
values = model.model_dump()

filter_query = Expression.compile_many(manager._filter)
await manager._collection.update_many(filter_query, {"$set": values})
await manager._collection.update_many(
filter_query, {"$set": values}
)

_filter = [
expression for expression in manager._filter if expression.key not in values
expression
for expression in manager._filter
if expression.key not in values
]
_filter.extend([Expression(key, "$eq", value) for key, value in values.items()])
_filter.extend(
[
Expression(key, "$eq", value)
for key, value in values.items()
]
)

manager._filter = _filter
return await manager._all()
Expand All @@ -556,7 +612,9 @@ async def bulk_update(self, **kwargs: Any) -> List[T]:
manager: "Manager" = self.clone()
return await manager.update_many(**kwargs)

async def get_document_by_id(self, id: Union[str, bson.ObjectId]) -> "Document":
async def get_document_by_id(
self, id: Union[str, bson.ObjectId]
) -> "Document":
"""
Gets a document by the id
"""
Expand Down
24 changes: 20 additions & 4 deletions tests/models/manager/test_query_builder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import date, datetime
from typing import AsyncGenerator, List, Optional

import pydantic
Expand All @@ -21,6 +22,7 @@ class Movie(Document):
year: int = mongoz.Integer()
tags: Optional[List[str]] = mongoz.Array(str, null=True)
uuid: Optional[ObjectId] = mongoz.ObjectId(null=True)
released_at: datetime = mongoz.DateTime(null=True)

class Meta:
registry = client
Expand All @@ -40,7 +42,9 @@ async def prepare_database() -> AsyncGenerator:


async def test_model_query_builder() -> None:
await Movie.objects.create(name="Downfall", year=2004)
await Movie.objects.create(
name="Downfall", year=2004, released_at=datetime.now()
)
await Movie.objects.create(name="The Two Towers", year=2002)
await Movie.objects.create(name="Casablanca", year=1942)
await Movie.objects.create(name="Gone with the wind", year=1939)
Expand Down Expand Up @@ -69,7 +73,9 @@ async def test_model_query_builder() -> None:
assert movie.name == "Downfall"
assert movie.year == 2004

movie = await Movie.objects.filter(name="Casablanca").filter(year=1942).get()
movie = (
await Movie.objects.filter(name="Casablanca").filter(year=1942).get()
)
assert movie.name == "Casablanca"
assert movie.year == 1942

Expand All @@ -81,7 +87,9 @@ async def test_model_query_builder() -> None:
assert movie.name == "Casablanca"
assert movie.year == 1942

movie = await Movie.objects.filter(year__gt=2000).filter(year__lt=2003).get()
movie = (
await Movie.objects.filter(year__gt=2000).filter(year__lt=2003).get()
)
assert movie.name == "The Two Towers"
assert movie.year == 2002

Expand Down Expand Up @@ -129,6 +137,12 @@ async def test_model_query_builder() -> None:
assert len(movies) == 1
assert movies[0].name.lower() == "gone with the Wind".lower()

movies = await Movie.objects.filter(released_at__date=date.today())
assert len(movies) == 1
assert movies[0].name == "Downfall"
assert movies[0].year == 2004


async def test_query_builder_in_list():
await Movie.objects.create(name="Downfall", year=2004)
await Movie.objects.create(name="The Two Towers", year=2002)
Expand All @@ -142,7 +156,9 @@ async def test_query_builder_in_list():
assert len(movies) == 2


@pytest.mark.parametrize("values", [{2002, 2004}, {"year": 2002}], ids=["as-set", "as-dict"])
@pytest.mark.parametrize(
"values", [{2002, 2004}, {"year": 2002}], ids=["as-set", "as-dict"]
)
async def test_query_builder_in_list_raise_assertation_error(values):
await Movie.objects.create(name="Downfall", year=2004)
await Movie.objects.create(name="The Two Towers", year=2002)
Expand Down
Loading