Skip to content

Commit

Permalink
✨ find_one and find_many with sort
Browse files Browse the repository at this point in the history
  • Loading branch information
mauro-andre committed Apr 10, 2024
1 parent f796dcf commit 825471d
Show file tree
Hide file tree
Showing 10 changed files with 188 additions and 17 deletions.
25 changes: 22 additions & 3 deletions pyodmongo/async_engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from pymongo.results import UpdateResult, DeleteResult
from ..models.responses import SaveResponse, DeleteResponse
from ..engine.utils import consolidate_dict, mount_base_pipeline
from ..services.query_operators import query_dict
from ..services.query_operators import query_dict, sort_dict
from ..models.paginate import ResponsePaginate
from ..models.query_operators import LogicalOperator, ComparisonOperator
from ..models.query_operators import LogicalOperator, ComparisonOperator, SortOperator
from ..models.db_model import DbModel
from datetime import datetime
from typing import TypeVar
Expand Down Expand Up @@ -149,6 +149,8 @@ async def find_one(
Model: type[Model],
query: ComparisonOperator | LogicalOperator = None,
raw_query: dict = None,
sort: SortOperator = None,
raw_sort: dict = None,
populate: bool = False,
as_dict: bool = False,
) -> type[Model]:
Expand All @@ -158,10 +160,18 @@ async def find_one(
raise TypeError(
'query argument must be a ComparisonOperator or LogicalOperator from pyodmongo.queries. If you really need to make a very specific query, use "raw_query" argument'
)
if sort and (type(sort) != SortOperator):
raise TypeError(
'sort argument must be a SortOperator from pyodmongo.queries. If you really need to make a very specific sort, use "raw_sort" argument'
)
raw_query = {} if not raw_query else raw_query
query = query_dict(query_operator=query, dct={}) if query else raw_query
raw_sort = {} if not raw_sort else raw_sort
sort = sort_dict(sort_operators=sort) if sort else raw_sort
pipeline = mount_base_pipeline(
Model=Model,
query=query_dict(query_operator=query, dct={}) if query else raw_query,
query=query,
sort=sort,
populate=populate,
)
pipeline += [{"$limit": 1}]
Expand All @@ -178,6 +188,8 @@ async def find_many(
Model: type[Model],
query: ComparisonOperator | LogicalOperator = None,
raw_query: dict = None,
sort: SortOperator = None,
raw_sort: dict = None,
populate: bool = False,
as_dict: bool = False,
paginate: bool = False,
Expand All @@ -190,11 +202,18 @@ async def find_many(
raise TypeError(
'query argument must be a ComparisonOperator or LogicalOperator from pyodmongo.queries. If you really need to make a very specific query, use "raw_query" argument'
)
if sort and (type(sort) != SortOperator):
raise TypeError(
'sort argument must be a SortOperator from pyodmongo.queries. If you really need to make a very specific sort, use "raw_sort" argument'
)
raw_query = {} if not raw_query else raw_query
query = query_dict(query_operator=query, dct={}) if query else raw_query
raw_sort = {} if not raw_sort else raw_sort
sort = sort_dict(sort_operators=sort) if sort else raw_sort
pipeline = mount_base_pipeline(
Model=Model,
query=query,
sort=sort,
populate=populate,
)
if not paginate:
Expand Down
25 changes: 22 additions & 3 deletions pyodmongo/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from pymongo.results import UpdateResult, DeleteResult
from ..models.responses import SaveResponse, DeleteResponse
from ..engine.utils import consolidate_dict, mount_base_pipeline
from ..services.query_operators import query_dict
from ..services.query_operators import query_dict, sort_dict
from ..models.paginate import ResponsePaginate
from ..models.query_operators import LogicalOperator, ComparisonOperator
from ..models.query_operators import LogicalOperator, ComparisonOperator, SortOperator
from ..models.db_model import DbModel
from datetime import datetime
from typing import TypeVar
Expand Down Expand Up @@ -148,6 +148,8 @@ def find_one(
Model: type[Model],
query: ComparisonOperator | LogicalOperator = None,
raw_query: dict = None,
sort: SortOperator = None,
raw_sort: dict = None,
populate: bool = False,
as_dict: bool = False,
) -> type[Model]:
Expand All @@ -157,10 +159,18 @@ def find_one(
raise TypeError(
'query argument must be a ComparisonOperator or LogicalOperator from pyodmongo.queries. If you really need to make a very specific query, use "raw_query" argument'
)
if sort and (type(sort) != SortOperator):
raise TypeError(
'sort argument must be a SortOperator from pyodmongo.queries. If you really need to make a very specific sort, use "raw_sort" argument'
)
raw_query = {} if not raw_query else raw_query
query = query_dict(query_operator=query, dct={}) if query else raw_query
raw_sort = {} if not raw_sort else raw_sort
sort = sort_dict(sort_operators=sort) if sort else raw_sort
pipeline = mount_base_pipeline(
Model=Model,
query=query_dict(query_operator=query, dct={}) if query else raw_query,
query=query,
sort=sort,
populate=populate,
)
pipeline += [{"$limit": 1}]
Expand All @@ -175,6 +185,8 @@ def find_many(
Model: type[Model],
query: ComparisonOperator | LogicalOperator = None,
raw_query: dict = None,
sort: SortOperator = None,
raw_sort: dict = None,
populate: bool = False,
as_dict: bool = False,
paginate: bool = False,
Expand All @@ -187,11 +199,18 @@ def find_many(
raise TypeError(
'query argument must be a ComparisonOperator or LogicalOperator from pyodmongo.queries. If you really need to make a very specific query, use "raw_query" argument'
)
if sort and (type(sort) != SortOperator):
raise TypeError(
'sort argument must be a SortOperator from pyodmongo.queries. If you really need to make a very specific sort, use "raw_sort" argument'
)
raw_query = {} if not raw_query else raw_query
query = query_dict(query_operator=query, dct={}) if query else raw_query
raw_sort = {} if not raw_sort else raw_sort
sort = sort_dict(sort_operators=sort) if sort else raw_sort
pipeline = mount_base_pipeline(
Model=Model,
query=query,
sort=sort,
populate=populate,
)
if not paginate:
Expand Down
7 changes: 4 additions & 3 deletions pyodmongo/engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,12 @@ def consolidate_dict(obj: BaseModel, dct: dict):
return dct


def mount_base_pipeline(Model, query, populate: bool = False):
def mount_base_pipeline(Model, query: dict, sort: dict, populate: bool = False):
match_stage = [{"$match": query}]
sort_stage = [{"$sort": sort}] if sort != {} else []
model_stage = Model._pipeline
reference_stage = Model._reference_pipeline
if populate:
return match_stage + model_stage + reference_stage
return match_stage + model_stage + reference_stage + sort_stage
else:
return match_stage + model_stage
return match_stage + model_stage + sort_stage
5 changes: 5 additions & 0 deletions pyodmongo/models/query_operators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pydantic import BaseModel
from .db_field_info import DbField
from typing import Any


Expand All @@ -15,3 +16,7 @@ class _LogicalOperator(BaseModel):

class LogicalOperator(_LogicalOperator):
operators: tuple[ComparisonOperator | _LogicalOperator, ...]


class SortOperator(BaseModel):
operators: tuple[tuple[DbField, int], ...]
1 change: 1 addition & 0 deletions pyodmongo/queries/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .comparison_operators import eq, gt, gte, in_, lt, lte, ne, nin, text
from .sort_operator import sort
from .logical_operators import and_, or_, nor
from .query_string import mount_query_filter
6 changes: 6 additions & 0 deletions pyodmongo/queries/sort_operator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from ..models.db_field_info import DbField
from ..models.query_operators import SortOperator


def sort(*operators: tuple[DbField, int]) -> SortOperator:
return SortOperator(operators=operators)
13 changes: 11 additions & 2 deletions pyodmongo/services/query_operators.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from ..models.query_operators import ComparisonOperator, LogicalOperator
from ..models.query_operators import ComparisonOperator, LogicalOperator, SortOperator


def comparison_operator_dict(co: ComparisonOperator):
return {co.path_str: {co.operator: co.value}}


def query_dict(query_operator: ComparisonOperator | LogicalOperator, dct: dict):
def query_dict(query_operator: ComparisonOperator | LogicalOperator, dct: dict) -> dict:
if isinstance(query_operator, ComparisonOperator):
return comparison_operator_dict(co=query_operator)
dct[query_operator.operator] = []
Expand All @@ -17,3 +17,12 @@ def query_dict(query_operator: ComparisonOperator | LogicalOperator, dct: dict):
query_dict(query_operator=operator, dct={})
)
return dct


def sort_dict(sort_operators: SortOperator) -> dict:
dct = {}
for operator in sort_operators.operators:
if operator[1] not in (1, -1):
raise ValueError("Only values 1 ascending and -1 descending are valid")
dct[operator[0].path_str] = operator[1]
return dct
49 changes: 45 additions & 4 deletions tests/test_async_crud_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
Id,
Field,
)
from pyodmongo.queries import eq, gte, gt, mount_query_filter
from pyodmongo.queries import eq, gte, gt, mount_query_filter, sort
from pyodmongo.engine.utils import consolidate_dict
from pydantic import ConfigDict, BaseModel
from typing import ClassVar
Expand Down Expand Up @@ -509,8 +509,8 @@ async def drop_collections_one_three():
await db._db[ClassOne._collection].drop()
await db._db[ClassThree._collection].drop()
yield
# await db._db[ClassOne._collection].drop()
# await db._db[ClassThree._collection].drop()
await db._db[ClassOne._collection].drop()
await db._db[ClassThree._collection].drop()


@pytest.mark.asyncio
Expand Down Expand Up @@ -538,8 +538,49 @@ async def test_nested_list_objects(drop_collections_one_three):
attr_3="obj_15", class_two_b=obj_13, class_two_b_list=[obj_13, obj_14]
)
await db.save(obj_15)
from pprint import pprint

obj_found = await db.find_one(Model=ClassThree, populate=True)
assert obj_found.id == obj_15.id
assert obj_found.class_two_b.attr_2_b == "obj_13"


class MySortClass(DbModel):
attr_1: str
attr_2: int
attr_3: datetime
_collection: ClassVar = "my_class_to_sort"


@pytest_asyncio.fixture()
async def drop_collection_for_test_sort():
await db._db[MySortClass._collection].drop()
yield
await db._db[MySortClass._collection].drop()


@pytest.mark.asyncio
async def test_sort_query(drop_collection_for_test_sort):
obj_list = [
MySortClass(
attr_1="Juliet", attr_2=100, attr_3=datetime(year=2023, month=1, day=20)
),
MySortClass(
attr_1="Albert", attr_2=50, attr_3=datetime(year=2025, month=1, day=20)
),
MySortClass(attr_1="Zack", attr_2=30, attr_3=datetime(year=2020, month=1, day=20)),
MySortClass(
attr_1="Charlie", attr_2=150, attr_3=datetime(year=2027, month=1, day=20)
),
MySortClass(
attr_1="Albert", attr_2=40, attr_3=datetime(year=2025, month=1, day=20)
),
]
await db.save_all(obj_list=obj_list)
sort_oprator = sort((MySortClass.attr_1, 1), (MySortClass.attr_2, 1))
result_many = await db.find_many(Model=MySortClass, sort=sort_oprator)
assert result_many[0] == obj_list[4]
assert result_many[1] == obj_list[1]

sort_oprator = sort((MySortClass.attr_3, 1))
result_one = await db.find_one(Model=MySortClass, sort=sort_oprator)
assert result_one == obj_list[2]
29 changes: 28 additions & 1 deletion tests/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
and_,
or_,
nor,
sort,
mount_query_filter,
)
from pyodmongo.services.query_operators import query_dict
from pyodmongo.services.query_operators import query_dict, sort_dict
import pytest


Expand Down Expand Up @@ -292,3 +293,29 @@ def test_logical_operator_inside_another():
}
query_dct = query_dict(query_operator=query, dct={})
assert query_dct == expected_dct


def test_sort_operator():
class MyNestedClass(BaseModel):
n: int

class MyClass(DbModel):
a: str
b: int
c: int
nested: MyNestedClass

sort_operator = sort(
(MyClass.nested.n, -1), (MyClass.b, 1), (MyClass.c, -1), (MyClass.a, 1)
)
assert sort_dict(sort_operators=sort_operator) == {
"nested.n": -1,
"b": 1,
"c": -1,
"a": 1,
}

with pytest.raises(
ValueError, match="Only values 1 ascending and -1 descending are valid"
):
sort_dict(sort_operators=sort((MyClass.a, 2)))
45 changes: 44 additions & 1 deletion tests/test_sync_crud_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ResponsePaginate,
Field,
)
from pyodmongo.queries import eq, gte, gt
from pyodmongo.queries import eq, gte, gt, sort
from pyodmongo.engine.utils import consolidate_dict
from pydantic import ConfigDict
from typing import ClassVar
Expand Down Expand Up @@ -352,3 +352,46 @@ def test_find_as_dict(create_find_dict_collection):
assert type(dct) == dict
obj_dict = db.find_one(Model=AsDict2, as_dict=True, populate=True)
assert type(obj_dict) == dict


class MySortClass(DbModel):
attr_1: str
attr_2: int
attr_3: datetime
_collection: ClassVar = "my_class_to_sort"


@pytest.fixture()
def drop_collection_for_test_sort():
db._db[MySortClass._collection].drop()
yield
db._db[MySortClass._collection].drop()


def test_sort_query(drop_collection_for_test_sort):
obj_list = [
MySortClass(
attr_1="Juliet", attr_2=100, attr_3=datetime(year=2023, month=1, day=20)
),
MySortClass(
attr_1="Albert", attr_2=50, attr_3=datetime(year=2025, month=1, day=20)
),
MySortClass(
attr_1="Zack", attr_2=30, attr_3=datetime(year=2020, month=1, day=20)
),
MySortClass(
attr_1="Charlie", attr_2=150, attr_3=datetime(year=2027, month=1, day=20)
),
MySortClass(
attr_1="Albert", attr_2=40, attr_3=datetime(year=2025, month=1, day=20)
),
]
db.save_all(obj_list=obj_list)
sort_oprator = sort((MySortClass.attr_1, 1), (MySortClass.attr_2, 1))
result_many = db.find_many(Model=MySortClass, sort=sort_oprator)
assert result_many[0] == obj_list[4]
assert result_many[1] == obj_list[1]

sort_oprator = sort((MySortClass.attr_3, 1))
result_one = db.find_one(Model=MySortClass, sort=sort_oprator)
assert result_one == obj_list[2]

0 comments on commit 825471d

Please sign in to comment.