Skip to content

Commit

Permalink
Merge pull request #134 from mauro-andre/engine_type_hints_responses
Browse files Browse the repository at this point in the history
✏️ Type hints response in engines
  • Loading branch information
mauro-andre authored Jun 5, 2024
2 parents afd8302 + d806532 commit dfeee6b
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 23 deletions.
44 changes: 24 additions & 20 deletions pyodmongo/engines/engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@
from ..models.query_operators import QueryOperator
from ..models.sort_operators import SortOperator
from ..models.paginate import ResponsePaginate
from typing import TypeVar, Type, Union
from .utils import consolidate_dict, mount_base_pipeline
from ..services.verify_subclasses import is_subclass
from asyncio import gather
from math import ceil


Model = TypeVar("Model", bound=DbModel)


class _Engine:
"""
Base class for database operations, providing common functionality for both synchronous and asynchronous engines.
Expand Down Expand Up @@ -88,7 +92,7 @@ def _set_tz_info(self, tz_info: timezone):
"""
return tz_info if tz_info else self._tz_info

def _update_many_operation(self, obj: DbModel, query_dict: dict, now):
def _update_many_operation(self, obj: Type[Model], query_dict: dict, now):
"""
Create an UpdateMany operation for bulk updates.
Expand Down Expand Up @@ -132,7 +136,7 @@ def _create_delete_operations_list(

def _create_save_operations_list(
self,
objs: list[DbModel],
objs: list[Type[Model]],
query: QueryOperator,
raw_query: dict,
):
Expand All @@ -153,7 +157,7 @@ def _create_save_operations_list(
now = datetime.now(self._tz_info)
now = now.replace(microsecond=int(now.microsecond / 1000) * 1000)
for obj in objs:
obj: DbModel
obj: Model
operation = self._update_many_operation(obj=obj, query_dict=query, now=now)
collection_name = obj._collection
try:
Expand All @@ -169,7 +173,7 @@ def _create_save_operations_list(
return indexes, operations, now

def _after_save(
self, result: BulkWriteResult, objs: list[DbModel], collection_name: str, now
self, result: BulkWriteResult, objs: list[Model], collection_name: str, now
):
"""
Perform post-save operations.
Expand Down Expand Up @@ -210,7 +214,7 @@ def _db_response(self, result: BulkWriteResult):

def _aggregate_cursor(
self,
Model: DbModel,
Model: Type[Model],
pipeline,
tz_info: timezone,
):
Expand All @@ -234,7 +238,7 @@ def _aggregate_cursor(

def _aggregate_pipeline(
self,
Model: DbModel,
Model: Type[Model],
query: QueryOperator,
raw_query: dict,
sort: SortOperator,
Expand Down Expand Up @@ -311,7 +315,7 @@ def __init__(self, mongo_uri, db_name, tz_info: timezone = None):
tz_info=tz_info,
)

async def save_all(self, obj_list: list[DbModel]) -> dict[str, DbResponse]:
async def save_all(self, obj_list: list[Model]) -> dict[str, DbResponse]:
"""
Save a list of objects to the database.
Expand All @@ -336,7 +340,7 @@ async def save_all(self, obj_list: list[DbModel]) -> dict[str, DbResponse]:
return response

async def save(
self, obj: DbModel, query: QueryOperator = None, raw_query: dict = None
self, obj: Model, query: QueryOperator = None, raw_query: dict = None
) -> DbResponse:
"""
Save a single object to the database.
Expand Down Expand Up @@ -367,15 +371,15 @@ async def save(

async def find_one(
self,
Model: DbModel,
Model: Type[Model],
query: QueryOperator = None,
raw_query: dict = None,
sort: SortOperator = None,
raw_sort: dict = None,
populate: bool = False,
as_dict: bool = False,
tz_info: timezone = None,
) -> DbModel:
) -> Model:
"""
Find a single document in the database.
Expand Down Expand Up @@ -413,7 +417,7 @@ async def find_one(

async def find_many(
self,
Model: DbModel,
Model: Type[Model],
query: QueryOperator = None,
raw_query: dict = None,
sort: SortOperator = None,
Expand All @@ -424,7 +428,7 @@ async def find_many(
paginate: bool = False,
current_page: int = 1,
docs_per_page: int = 1000,
) -> list[DbModel]:
) -> Union[list[Model], ResponsePaginate]:
"""
Find multiple documents in the database.
Expand Down Expand Up @@ -487,7 +491,7 @@ async def _count():

async def delete(
self,
Model: DbModel,
Model: Type[Model],
query: QueryOperator = None,
raw_query: dict = None,
delete_one: bool = False,
Expand Down Expand Up @@ -533,7 +537,7 @@ def __init__(self, mongo_uri, db_name, tz_info: timezone = None):
tz_info=tz_info,
)

def save_all(self, obj_list: list[DbModel]) -> dict[str, DbResponse]:
def save_all(self, obj_list: list[Model]) -> dict[str, DbResponse]:
"""
Save a list of objects to the database.
Expand All @@ -558,7 +562,7 @@ def save_all(self, obj_list: list[DbModel]) -> dict[str, DbResponse]:
return response

def save(
self, obj: DbModel, query: QueryOperator = None, raw_query: dict = None
self, obj: Model, query: QueryOperator = None, raw_query: dict = None
) -> DbResponse:
"""
Save a single object to the database.
Expand Down Expand Up @@ -587,15 +591,15 @@ def save(

def find_one(
self,
Model: DbModel,
Model: Type[Model],
query: QueryOperator = None,
raw_query: dict = None,
sort: SortOperator = None,
raw_sort: dict = None,
populate: bool = False,
as_dict: bool = False,
tz_info: timezone = None,
) -> DbModel:
) -> Model:
"""
Find a single document in the database.
Expand Down Expand Up @@ -633,7 +637,7 @@ def find_one(

def find_many(
self,
Model: DbModel,
Model: Type[Model],
query: QueryOperator = None,
raw_query: dict = None,
sort: SortOperator = None,
Expand All @@ -644,7 +648,7 @@ def find_many(
paginate: bool = False,
current_page: int = 1,
docs_per_page: int = 1000,
) -> list[DbModel]:
) -> Union[list[Model], ResponsePaginate]:
"""
Find multiple documents in the database.
Expand Down Expand Up @@ -706,7 +710,7 @@ def _count():

def delete(
self,
Model: DbModel,
Model: Type[Model],
query: QueryOperator = None,
raw_query: dict = None,
delete_one: bool = False,
Expand Down
7 changes: 4 additions & 3 deletions tests/test_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class MyClass1(DbModel):

response_0: dict[str, DbResponse] = await async_engine.save_all([obj_0, obj_2])
response_1: dict[str, DbResponse] = engine.save_all([obj_1, obj_3])

assert response_0["my_class_0"].upserted_count == 1
assert response_0["my_class_0"].upserted_ids[0] == obj_0.id
assert response_0["my_class_1"].upserted_count == 1
Expand Down Expand Up @@ -117,10 +117,10 @@ class MyClass0(DbModel):
obj_to_find_0_49: MyClass0 = copy.deepcopy(objs_0_49[24])
obj_to_find_50_99: MyClass0 = copy.deepcopy(objs_50_99[24])

obj_found_0_49: MyClass0 = await async_engine.find_one(
obj_found_0_49 = await async_engine.find_one(
Model=MyClass0, query=MyClass0.id == obj_to_find_0_49.id
)
obj_found_50_49: MyClass0 = engine.find_one(
obj_found_50_49 = engine.find_one(
Model=MyClass0, query=MyClass0.id == obj_to_find_50_99.id
)

Expand Down Expand Up @@ -183,5 +183,6 @@ class MyClass0(DbModel):

find_result_0 = await async_engine.find_many(Model=MyClass0)
find_result_1 = engine.find_many(Model=MyClass0)

assert len(find_result_0) == 80
assert len(find_result_1) == 80

0 comments on commit dfeee6b

Please sign in to comment.