Skip to content

Commit

Permalink
--wip-- [skip ci]
Browse files Browse the repository at this point in the history
  • Loading branch information
art049 committed Aug 25, 2022
1 parent c939c22 commit 2b07058
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 2 deletions.
25 changes: 24 additions & 1 deletion odmantic/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import (
TYPE_CHECKING,
Any,
List,
Optional,
Pattern,
Sequence,
Expand Down Expand Up @@ -202,7 +203,19 @@ class ODMField(ODMBaseField):

__slots__ = ("primary_field",)
__allowed_operators__ = set(
("eq", "ne", "in_", "not_in", "lt", "lte", "gt", "gte", "match", "asc", "desc")
(
"eq",
"ne",
"in_",
"not_in",
"lt",
"lte",
"gt",
"gte",
"match",
"asc",
"desc",
)
)

def __init__(
Expand Down Expand Up @@ -302,6 +315,16 @@ def __getattribute__(self, name: str) -> Any:
f"attribute {name} not found in {field.model.__name__}"
)
return FieldProxy(parent=self, field=child_field)
elif isinstance(field, ODMField) and name == "any":
if name == "any":
return FieldProxy(
parent=self,
field=ODMField(
primary_field=False,
key_name="$any",
model_config=field.model_config,
),
)

if name not in field.__allowed_operators__:
raise AttributeError(
Expand Down
21 changes: 21 additions & 0 deletions odmantic/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,27 @@ def match(field: FieldProxyAny, pattern: Union[Pattern, str]) -> QueryExpression
return QueryExpression({+field: r})


def contain(field: FieldProxyAny, expected_content: Sequence[Any]) -> QueryExpression:
"""Select the instances where `field` contains elements in `expected_content`.
!!! warning
The order of the elements is not taken into account and this will only check
that the `expected_content` are included in the actual content of the field. To
have stricter conditions, it's possible to use `eq`(==) directly.
"""
return _cmp_expression(field, "$all", expected_content)


def size(field: FieldProxyAny, length: int) -> QueryExpression:
"""Select the instances where `field` is an array of size `length`."""
return _cmp_expression(field, "$size", length)


def any_(field: FieldProxyAny, expression: QueryDictBool) -> Any:
"""Select instances where an element of `field` matches the `expression`."""
return _cmp_expression(field, "$elemMatch", expression)


class SortExpression(Dict[str, Literal[-1, 1]]):
"""Base object used to build sort queries."""

Expand Down
38 changes: 37 additions & 1 deletion tests/unit/test_query.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from odmantic.query import QueryExpression, SortExpression, asc
from typing import List

from odmantic.model import Model
from odmantic.query import QueryExpression, SortExpression, any_, asc, eq, gte
from tests.zoo.book_embedded import Book, Publisher
from tests.zoo.tree import TreeKind, TreeModel

Expand Down Expand Up @@ -38,3 +41,36 @@ def test_sort_repr():

def test_sort_empty_repr():
assert repr(SortExpression()) == "SortExpression()"


class ModelWithIntArray(Model):
array: List[int]


def test_array_any_eq():
expected = {"array": {"$elemMatch": {"$eq": 42}}}
assert any_(ModelWithIntArray.array, {"$eq": 42}) == expected
assert eq(any_(ModelWithIntArray.array), 42) == expected
# assert ModelWithIntArray.array.any().eq(42) == expected
# assert (ModelWithIntArray.array.any() == 42) == expected


def test_array_any_gte():
expected = {"array": {"$elemMatch": {"$gte": 42}}}
assert any_(ModelWithIntArray.array, {"$gte": 42}) == expected
# assert ModelWithIntArray.array.any().gte(42) == expected
# assert (ModelWithIntArray.array.any() >= 42) == expected


# def test_array_all_eq():
# expected = {"array": {"$all": {"$eq": 42}}}
# assert eq(ModelWithIntArray.array.all(), 42) == expected
# assert ModelWithIntArray.array.all().eq(42) == expected
# assert (ModelWithIntArray.array.all() == 42) == expected


# def test_array_all_gte():
# expected = {"array": {"$all": {"$gte": 42}}}
# assert ModelWithIntArray.array.all().gte(42) == expected
# assert gte(ModelWithIntArray.array.all(), 42) == expected
# assert (ModelWithIntArray.array.all() >= 42) == expected

0 comments on commit 2b07058

Please sign in to comment.