diff --git a/pyodmongo/async_engine/engine.py b/pyodmongo/async_engine/engine.py index a446856..6d56ec9 100644 --- a/pyodmongo/async_engine/engine.py +++ b/pyodmongo/async_engine/engine.py @@ -2,9 +2,9 @@ from pymongo.results import UpdateResult, DeleteResult from ..models.responses import SaveResponse, DeleteResponse from ..engine.utils import consolidate_dict, mount_base_pipeline -from ..services.query_operators import query_dict +from ..services.query_operators import query_dict, sort_dict from ..models.paginate import ResponsePaginate -from ..models.query_operators import LogicalOperator, ComparisonOperator +from ..models.query_operators import LogicalOperator, ComparisonOperator, SortOperator from ..models.db_model import DbModel from datetime import datetime from typing import TypeVar @@ -149,6 +149,8 @@ async def find_one( Model: type[Model], query: ComparisonOperator | LogicalOperator = None, raw_query: dict = None, + sort: SortOperator = None, + raw_sort: dict = None, populate: bool = False, as_dict: bool = False, ) -> type[Model]: @@ -158,10 +160,18 @@ async def find_one( raise TypeError( 'query argument must be a ComparisonOperator or LogicalOperator from pyodmongo.queries. If you really need to make a very specific query, use "raw_query" argument' ) + if sort and (type(sort) != SortOperator): + raise TypeError( + 'sort argument must be a SortOperator from pyodmongo.queries. If you really need to make a very specific sort, use "raw_sort" argument' + ) raw_query = {} if not raw_query else raw_query + query = query_dict(query_operator=query, dct={}) if query else raw_query + raw_sort = {} if not raw_sort else raw_sort + sort = sort_dict(sort_operators=sort) if sort else raw_sort pipeline = mount_base_pipeline( Model=Model, - query=query_dict(query_operator=query, dct={}) if query else raw_query, + query=query, + sort=sort, populate=populate, ) pipeline += [{"$limit": 1}] @@ -178,6 +188,8 @@ async def find_many( Model: type[Model], query: ComparisonOperator | LogicalOperator = None, raw_query: dict = None, + sort: SortOperator = None, + raw_sort: dict = None, populate: bool = False, as_dict: bool = False, paginate: bool = False, @@ -190,11 +202,18 @@ async def find_many( raise TypeError( 'query argument must be a ComparisonOperator or LogicalOperator from pyodmongo.queries. If you really need to make a very specific query, use "raw_query" argument' ) + if sort and (type(sort) != SortOperator): + raise TypeError( + 'sort argument must be a SortOperator from pyodmongo.queries. If you really need to make a very specific sort, use "raw_sort" argument' + ) raw_query = {} if not raw_query else raw_query query = query_dict(query_operator=query, dct={}) if query else raw_query + raw_sort = {} if not raw_sort else raw_sort + sort = sort_dict(sort_operators=sort) if sort else raw_sort pipeline = mount_base_pipeline( Model=Model, query=query, + sort=sort, populate=populate, ) if not paginate: diff --git a/pyodmongo/engine/engine.py b/pyodmongo/engine/engine.py index a2071bd..bc0d4a6 100644 --- a/pyodmongo/engine/engine.py +++ b/pyodmongo/engine/engine.py @@ -2,9 +2,9 @@ from pymongo.results import UpdateResult, DeleteResult from ..models.responses import SaveResponse, DeleteResponse from ..engine.utils import consolidate_dict, mount_base_pipeline -from ..services.query_operators import query_dict +from ..services.query_operators import query_dict, sort_dict from ..models.paginate import ResponsePaginate -from ..models.query_operators import LogicalOperator, ComparisonOperator +from ..models.query_operators import LogicalOperator, ComparisonOperator, SortOperator from ..models.db_model import DbModel from datetime import datetime from typing import TypeVar @@ -148,6 +148,8 @@ def find_one( Model: type[Model], query: ComparisonOperator | LogicalOperator = None, raw_query: dict = None, + sort: SortOperator = None, + raw_sort: dict = None, populate: bool = False, as_dict: bool = False, ) -> type[Model]: @@ -157,10 +159,18 @@ def find_one( raise TypeError( 'query argument must be a ComparisonOperator or LogicalOperator from pyodmongo.queries. If you really need to make a very specific query, use "raw_query" argument' ) + if sort and (type(sort) != SortOperator): + raise TypeError( + 'sort argument must be a SortOperator from pyodmongo.queries. If you really need to make a very specific sort, use "raw_sort" argument' + ) raw_query = {} if not raw_query else raw_query + query = query_dict(query_operator=query, dct={}) if query else raw_query + raw_sort = {} if not raw_sort else raw_sort + sort = sort_dict(sort_operators=sort) if sort else raw_sort pipeline = mount_base_pipeline( Model=Model, - query=query_dict(query_operator=query, dct={}) if query else raw_query, + query=query, + sort=sort, populate=populate, ) pipeline += [{"$limit": 1}] @@ -175,6 +185,8 @@ def find_many( Model: type[Model], query: ComparisonOperator | LogicalOperator = None, raw_query: dict = None, + sort: SortOperator = None, + raw_sort: dict = None, populate: bool = False, as_dict: bool = False, paginate: bool = False, @@ -187,11 +199,18 @@ def find_many( raise TypeError( 'query argument must be a ComparisonOperator or LogicalOperator from pyodmongo.queries. If you really need to make a very specific query, use "raw_query" argument' ) + if sort and (type(sort) != SortOperator): + raise TypeError( + 'sort argument must be a SortOperator from pyodmongo.queries. If you really need to make a very specific sort, use "raw_sort" argument' + ) raw_query = {} if not raw_query else raw_query query = query_dict(query_operator=query, dct={}) if query else raw_query + raw_sort = {} if not raw_sort else raw_sort + sort = sort_dict(sort_operators=sort) if sort else raw_sort pipeline = mount_base_pipeline( Model=Model, query=query, + sort=sort, populate=populate, ) if not paginate: diff --git a/pyodmongo/engine/utils.py b/pyodmongo/engine/utils.py index 24b37af..e43e97c 100644 --- a/pyodmongo/engine/utils.py +++ b/pyodmongo/engine/utils.py @@ -48,11 +48,12 @@ def consolidate_dict(obj: BaseModel, dct: dict): return dct -def mount_base_pipeline(Model, query, populate: bool = False): +def mount_base_pipeline(Model, query: dict, sort: dict, populate: bool = False): match_stage = [{"$match": query}] + sort_stage = [{"$sort": sort}] if sort != {} else [] model_stage = Model._pipeline reference_stage = Model._reference_pipeline if populate: - return match_stage + model_stage + reference_stage + return match_stage + model_stage + reference_stage + sort_stage else: - return match_stage + model_stage + return match_stage + model_stage + sort_stage diff --git a/pyodmongo/models/query_operators.py b/pyodmongo/models/query_operators.py index 39703f7..dcb7f69 100644 --- a/pyodmongo/models/query_operators.py +++ b/pyodmongo/models/query_operators.py @@ -1,4 +1,5 @@ from pydantic import BaseModel +from .db_field_info import DbField from typing import Any @@ -15,3 +16,7 @@ class _LogicalOperator(BaseModel): class LogicalOperator(_LogicalOperator): operators: tuple[ComparisonOperator | _LogicalOperator, ...] + + +class SortOperator(BaseModel): + operators: tuple[tuple[DbField, int], ...] diff --git a/pyodmongo/queries/__init__.py b/pyodmongo/queries/__init__.py index 98c8dd1..5425455 100644 --- a/pyodmongo/queries/__init__.py +++ b/pyodmongo/queries/__init__.py @@ -1,3 +1,4 @@ from .comparison_operators import eq, gt, gte, in_, lt, lte, ne, nin, text +from .sort_operator import sort from .logical_operators import and_, or_, nor from .query_string import mount_query_filter diff --git a/pyodmongo/queries/sort_operator.py b/pyodmongo/queries/sort_operator.py new file mode 100644 index 0000000..18718c3 --- /dev/null +++ b/pyodmongo/queries/sort_operator.py @@ -0,0 +1,6 @@ +from ..models.db_field_info import DbField +from ..models.query_operators import SortOperator + + +def sort(*operators: tuple[DbField, int]) -> SortOperator: + return SortOperator(operators=operators) diff --git a/pyodmongo/services/query_operators.py b/pyodmongo/services/query_operators.py index cb159f4..549787a 100644 --- a/pyodmongo/services/query_operators.py +++ b/pyodmongo/services/query_operators.py @@ -1,11 +1,11 @@ -from ..models.query_operators import ComparisonOperator, LogicalOperator +from ..models.query_operators import ComparisonOperator, LogicalOperator, SortOperator def comparison_operator_dict(co: ComparisonOperator): return {co.path_str: {co.operator: co.value}} -def query_dict(query_operator: ComparisonOperator | LogicalOperator, dct: dict): +def query_dict(query_operator: ComparisonOperator | LogicalOperator, dct: dict) -> dict: if isinstance(query_operator, ComparisonOperator): return comparison_operator_dict(co=query_operator) dct[query_operator.operator] = [] @@ -17,3 +17,12 @@ def query_dict(query_operator: ComparisonOperator | LogicalOperator, dct: dict): query_dict(query_operator=operator, dct={}) ) return dct + + +def sort_dict(sort_operators: SortOperator) -> dict: + dct = {} + for operator in sort_operators.operators: + if operator[1] not in (1, -1): + raise ValueError("Only values 1 ascending and -1 descending are valid") + dct[operator[0].path_str] = operator[1] + return dct diff --git a/tests/test_async_crud_db.py b/tests/test_async_crud_db.py index 42facd7..dd4a71b 100644 --- a/tests/test_async_crud_db.py +++ b/tests/test_async_crud_db.py @@ -7,7 +7,7 @@ Id, Field, ) -from pyodmongo.queries import eq, gte, gt, mount_query_filter +from pyodmongo.queries import eq, gte, gt, mount_query_filter, sort from pyodmongo.engine.utils import consolidate_dict from pydantic import ConfigDict, BaseModel from typing import ClassVar @@ -509,8 +509,8 @@ async def drop_collections_one_three(): await db._db[ClassOne._collection].drop() await db._db[ClassThree._collection].drop() yield - # await db._db[ClassOne._collection].drop() - # await db._db[ClassThree._collection].drop() + await db._db[ClassOne._collection].drop() + await db._db[ClassThree._collection].drop() @pytest.mark.asyncio @@ -538,8 +538,49 @@ async def test_nested_list_objects(drop_collections_one_three): attr_3="obj_15", class_two_b=obj_13, class_two_b_list=[obj_13, obj_14] ) await db.save(obj_15) - from pprint import pprint obj_found = await db.find_one(Model=ClassThree, populate=True) assert obj_found.id == obj_15.id assert obj_found.class_two_b.attr_2_b == "obj_13" + + +class MySortClass(DbModel): + attr_1: str + attr_2: int + attr_3: datetime + _collection: ClassVar = "my_class_to_sort" + + +@pytest_asyncio.fixture() +async def drop_collection_for_test_sort(): + await db._db[MySortClass._collection].drop() + yield + await db._db[MySortClass._collection].drop() + + +@pytest.mark.asyncio +async def test_sort_query(drop_collection_for_test_sort): + obj_list = [ + MySortClass( + attr_1="Juliet", attr_2=100, attr_3=datetime(year=2023, month=1, day=20) + ), + MySortClass( + attr_1="Albert", attr_2=50, attr_3=datetime(year=2025, month=1, day=20) + ), + MySortClass(attr_1="Zack", attr_2=30, attr_3=datetime(year=2020, month=1, day=20)), + MySortClass( + attr_1="Charlie", attr_2=150, attr_3=datetime(year=2027, month=1, day=20) + ), + MySortClass( + attr_1="Albert", attr_2=40, attr_3=datetime(year=2025, month=1, day=20) + ), + ] + await db.save_all(obj_list=obj_list) + sort_oprator = sort((MySortClass.attr_1, 1), (MySortClass.attr_2, 1)) + result_many = await db.find_many(Model=MySortClass, sort=sort_oprator) + assert result_many[0] == obj_list[4] + assert result_many[1] == obj_list[1] + + sort_oprator = sort((MySortClass.attr_3, 1)) + result_one = await db.find_one(Model=MySortClass, sort=sort_oprator) + assert result_one == obj_list[2] diff --git a/tests/test_queries.py b/tests/test_queries.py index cc6e09f..7f12094 100644 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -16,9 +16,10 @@ and_, or_, nor, + sort, mount_query_filter, ) -from pyodmongo.services.query_operators import query_dict +from pyodmongo.services.query_operators import query_dict, sort_dict import pytest @@ -292,3 +293,29 @@ def test_logical_operator_inside_another(): } query_dct = query_dict(query_operator=query, dct={}) assert query_dct == expected_dct + + +def test_sort_operator(): + class MyNestedClass(BaseModel): + n: int + + class MyClass(DbModel): + a: str + b: int + c: int + nested: MyNestedClass + + sort_operator = sort( + (MyClass.nested.n, -1), (MyClass.b, 1), (MyClass.c, -1), (MyClass.a, 1) + ) + assert sort_dict(sort_operators=sort_operator) == { + "nested.n": -1, + "b": 1, + "c": -1, + "a": 1, + } + + with pytest.raises( + ValueError, match="Only values 1 ascending and -1 descending are valid" + ): + sort_dict(sort_operators=sort((MyClass.a, 2))) diff --git a/tests/test_sync_crud_db.py b/tests/test_sync_crud_db.py index 9884393..2401911 100644 --- a/tests/test_sync_crud_db.py +++ b/tests/test_sync_crud_db.py @@ -7,7 +7,7 @@ ResponsePaginate, Field, ) -from pyodmongo.queries import eq, gte, gt +from pyodmongo.queries import eq, gte, gt, sort from pyodmongo.engine.utils import consolidate_dict from pydantic import ConfigDict from typing import ClassVar @@ -352,3 +352,46 @@ def test_find_as_dict(create_find_dict_collection): assert type(dct) == dict obj_dict = db.find_one(Model=AsDict2, as_dict=True, populate=True) assert type(obj_dict) == dict + + +class MySortClass(DbModel): + attr_1: str + attr_2: int + attr_3: datetime + _collection: ClassVar = "my_class_to_sort" + + +@pytest.fixture() +def drop_collection_for_test_sort(): + db._db[MySortClass._collection].drop() + yield + db._db[MySortClass._collection].drop() + + +def test_sort_query(drop_collection_for_test_sort): + obj_list = [ + MySortClass( + attr_1="Juliet", attr_2=100, attr_3=datetime(year=2023, month=1, day=20) + ), + MySortClass( + attr_1="Albert", attr_2=50, attr_3=datetime(year=2025, month=1, day=20) + ), + MySortClass( + attr_1="Zack", attr_2=30, attr_3=datetime(year=2020, month=1, day=20) + ), + MySortClass( + attr_1="Charlie", attr_2=150, attr_3=datetime(year=2027, month=1, day=20) + ), + MySortClass( + attr_1="Albert", attr_2=40, attr_3=datetime(year=2025, month=1, day=20) + ), + ] + db.save_all(obj_list=obj_list) + sort_oprator = sort((MySortClass.attr_1, 1), (MySortClass.attr_2, 1)) + result_many = db.find_many(Model=MySortClass, sort=sort_oprator) + assert result_many[0] == obj_list[4] + assert result_many[1] == obj_list[1] + + sort_oprator = sort((MySortClass.attr_3, 1)) + result_one = db.find_one(Model=MySortClass, sort=sort_oprator) + assert result_one == obj_list[2]