From f80d11a72b8cf8ce0c2d5bdebb837b6966609164 Mon Sep 17 00:00:00 2001 From: Mauro Andre Date: Fri, 24 May 2024 19:19:39 -0300 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20save?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyodmongo/__init__.py | 2 +- pyodmongo/engines/engines.py | 59 ++++++++++++++++++++++++++++++----- pyodmongo/models/responses.py | 10 ++++++ tests/test_engines.py | 49 ++++++++++++++++++++++++++--- 4 files changed, 107 insertions(+), 13 deletions(-) diff --git a/pyodmongo/__init__.py b/pyodmongo/__init__.py index a86e908..76c6241 100644 --- a/pyodmongo/__init__.py +++ b/pyodmongo/__init__.py @@ -4,5 +4,5 @@ from .models.db_model import DbModel, MainBaseModel from .models.id_model import Id from .models.paginate import ResponsePaginate -from .models.responses import SaveResponse, DeleteResponse +from .models.responses import DbResponse, SaveResponse, DeleteResponse from .models.fields import Field diff --git a/pyodmongo/engines/engines.py b/pyodmongo/engines/engines.py index af47b4c..9e5024a 100644 --- a/pyodmongo/engines/engines.py +++ b/pyodmongo/engines/engines.py @@ -4,6 +4,8 @@ from datetime import datetime, timezone, UTC from bson import ObjectId from ..models.db_model import DbModel +from ..models.id_model import Id +from ..models.responses import DbResponse from ..models.query_operators import QueryOperator from ..engine.utils import consolidate_dict, mount_base_pipeline @@ -61,10 +63,21 @@ def _after_save( filter(lambda x: x._collection == collection_name, objs) ) for index, obj_id in result.upserted_ids.items(): - objs_from_collection[index].id = obj_id + objs_from_collection[index].id = Id(obj_id) objs_from_collection[index].created_at = now objs_from_collection[index].updated_at = now + def _db_response(self, result: BulkWriteResult): + return DbResponse( + acknowledged=result.acknowledged, + deleted_count=result.deleted_count, + inserted_count=result.inserted_count, + matched_count=result.matched_count, + modified_count=result.modified_count, + upserted_count=result.upserted_count, + upserted_ids=result.upserted_ids, + ) + class AsyncDbEngine(_Engine): def __init__(self, mongo_uri, db_name, tz_info: timezone = None): @@ -75,9 +88,9 @@ def __init__(self, mongo_uri, db_name, tz_info: timezone = None): tz_info=tz_info, ) - async def save_all(self, objs: list[DbModel]): + async def save_all(self, obj_list: list[DbModel]): indexes, operations, now = self._create_operations_list( - objs=objs, query=None, raw_query=None + objs=obj_list, query=None, raw_query=None ) for collection_name, index_list in indexes.items(): await self._db[collection_name].create_indexes(index_list) @@ -86,9 +99,27 @@ async def save_all(self, objs: list[DbModel]): operation_list ) self._after_save( - result=result, objs=objs, collection_name=collection_name, now=now + result=result, objs=obj_list, collection_name=collection_name, now=now ) + async def save( + self, obj: DbModel, query: QueryOperator = None, raw_query: dict = None + ): + indexes, operations, now = self._create_operations_list( + objs=[obj], query=query, raw_query=raw_query + ) + collection_name = obj._collection + index_list = indexes[collection_name] + await self._db[collection_name].create_indexes(index_list) + operation_list = operations[collection_name] + result: BulkWriteResult = await self._db[collection_name].bulk_write( + operation_list + ) + self._after_save( + result=result, objs=[obj], collection_name=collection_name, now=now + ) + return self._db_response(result=result) + class DbEngine(_Engine): def __init__(self, mongo_uri, db_name, tz_info: timezone = None): @@ -99,9 +130,9 @@ def __init__(self, mongo_uri, db_name, tz_info: timezone = None): tz_info=tz_info, ) - def save_all(self, objs: list[DbModel]): + def save_all(self, obj_list: list[DbModel]): indexes, operations, now = self._create_operations_list( - objs=objs, query=None, raw_query=None + objs=obj_list, query=None, raw_query=None ) for collection_name, index_list in indexes.items(): self._db[collection_name].create_indexes(index_list) @@ -110,5 +141,19 @@ def save_all(self, objs: list[DbModel]): operation_list ) self._after_save( - result=result, objs=objs, collection_name=collection_name, now=now + result=result, objs=obj_list, collection_name=collection_name, now=now ) + + def save(self, obj: DbModel, query: QueryOperator = None, raw_query: dict = None): + indexes, operations, now = self._create_operations_list( + objs=[obj], query=query, raw_query=raw_query + ) + collection_name = obj._collection + index_list = indexes[collection_name] + self._db[collection_name].create_indexes(index_list) + operation_list = operations[collection_name] + result: BulkWriteResult = self._db[collection_name].bulk_write(operation_list) + self._after_save( + result=result, objs=[obj], collection_name=collection_name, now=now + ) + return self._db_response(result=result) diff --git a/pyodmongo/models/responses.py b/pyodmongo/models/responses.py index f6bf722..2bc4f58 100644 --- a/pyodmongo/models/responses.py +++ b/pyodmongo/models/responses.py @@ -2,6 +2,16 @@ from .id_model import Id +class DbResponse(BaseModel): + acknowledged: bool + deleted_count: int + inserted_count: int + matched_count: int + modified_count: int + upserted_count: int + upserted_ids: dict[int, Id] + + class SaveResponse(BaseModel): """ Represents the response from a save operation (insert or update) in PyODMongo. This diff --git a/tests/test_engines.py b/tests/test_engines.py index 6190581..93907e1 100644 --- a/tests/test_engines.py +++ b/tests/test_engines.py @@ -2,8 +2,9 @@ from typing import ClassVar import pytest import pytest_asyncio -from pyodmongo import AsyncDbEngine, DbEngine, DbModel, Field +from pyodmongo import AsyncDbEngine, DbEngine, DbModel, Field, DbResponse from bson import ObjectId +import copy mongo_uri = "mongodb://localhost:27017" db_name = "pyodmongo_pytest" @@ -14,7 +15,7 @@ @pytest_asyncio.fixture() -async def drop_collection(): +async def drop_db(): await async_engine._client.drop_database(db_name) engine._client.drop_database(db_name) yield @@ -23,9 +24,7 @@ async def drop_collection(): @pytest.mark.asyncio -async def test_save_all_upsert(drop_collection): - print() - +async def test_save_all(drop_db): class MyClass0(DbModel): attr_0: str = Field(index=True) attr_1: int = Field(index=True) @@ -48,3 +47,43 @@ class MyClass1(DbModel): assert ObjectId.is_valid(obj_1.id) assert ObjectId.is_valid(obj_2.id) assert ObjectId.is_valid(obj_3.id) + + id_0 = copy.copy(obj_0.id) + id_1 = copy.copy(obj_1.id) + id_2 = copy.copy(obj_2.id) + id_3 = copy.copy(obj_3.id) + + obj_0.attr_0 = "zero_zero" + obj_1.attr_0 = "one_one" + obj_2.attr_2 = "two_two" + obj_3.attr_2 = "three_three" + + await async_engine.save_all([obj_1, obj_3]) + engine.save_all([obj_0, obj_2]) + + assert obj_0.id == id_0 + assert obj_1.id == id_1 + assert obj_2.id == id_2 + assert obj_3.id == id_3 + + +@pytest.mark.asyncio +async def test_save(drop_db): + class MyClass0(DbModel): + attr_0: str = Field(index=True) + attr_1: int = Field(index=True) + _collection: ClassVar = "my_class_0" + + class MyClass1(DbModel): + attr_2: str = Field(index=True) + attr_3: int = Field(index=True) + _collection: ClassVar = "my_class_1" + + obj_0 = MyClass0(attr_0="zero", attr_1=0) + obj_1 = MyClass1(attr_2="two", attr_3=2) + + response_0: DbResponse = await async_engine.save(obj_0) + response_1: DbResponse = engine.save(obj_1) + + assert obj_0.id == response_0.upserted_ids[0] + assert obj_1.id == response_1.upserted_ids[0]