Skip to content

Commit

Permalink
♻️ save
Browse files Browse the repository at this point in the history
  • Loading branch information
mauro-andre committed May 24, 2024
1 parent 9d29588 commit f80d11a
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 13 deletions.
2 changes: 1 addition & 1 deletion pyodmongo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@
from .models.db_model import DbModel, MainBaseModel
from .models.id_model import Id
from .models.paginate import ResponsePaginate
from .models.responses import SaveResponse, DeleteResponse
from .models.responses import DbResponse, SaveResponse, DeleteResponse
from .models.fields import Field
59 changes: 52 additions & 7 deletions pyodmongo/engines/engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from datetime import datetime, timezone, UTC
from bson import ObjectId
from ..models.db_model import DbModel
from ..models.id_model import Id
from ..models.responses import DbResponse
from ..models.query_operators import QueryOperator
from ..engine.utils import consolidate_dict, mount_base_pipeline

Expand Down Expand Up @@ -61,10 +63,21 @@ def _after_save(
filter(lambda x: x._collection == collection_name, objs)
)
for index, obj_id in result.upserted_ids.items():
objs_from_collection[index].id = obj_id
objs_from_collection[index].id = Id(obj_id)
objs_from_collection[index].created_at = now
objs_from_collection[index].updated_at = now

def _db_response(self, result: BulkWriteResult):
return DbResponse(
acknowledged=result.acknowledged,
deleted_count=result.deleted_count,
inserted_count=result.inserted_count,
matched_count=result.matched_count,
modified_count=result.modified_count,
upserted_count=result.upserted_count,
upserted_ids=result.upserted_ids,
)


class AsyncDbEngine(_Engine):
def __init__(self, mongo_uri, db_name, tz_info: timezone = None):
Expand All @@ -75,9 +88,9 @@ def __init__(self, mongo_uri, db_name, tz_info: timezone = None):
tz_info=tz_info,
)

async def save_all(self, objs: list[DbModel]):
async def save_all(self, obj_list: list[DbModel]):
indexes, operations, now = self._create_operations_list(
objs=objs, query=None, raw_query=None
objs=obj_list, query=None, raw_query=None
)
for collection_name, index_list in indexes.items():
await self._db[collection_name].create_indexes(index_list)
Expand All @@ -86,9 +99,27 @@ async def save_all(self, objs: list[DbModel]):
operation_list
)
self._after_save(
result=result, objs=objs, collection_name=collection_name, now=now
result=result, objs=obj_list, collection_name=collection_name, now=now
)

async def save(
self, obj: DbModel, query: QueryOperator = None, raw_query: dict = None
):
indexes, operations, now = self._create_operations_list(
objs=[obj], query=query, raw_query=raw_query
)
collection_name = obj._collection
index_list = indexes[collection_name]
await self._db[collection_name].create_indexes(index_list)
operation_list = operations[collection_name]
result: BulkWriteResult = await self._db[collection_name].bulk_write(
operation_list
)
self._after_save(
result=result, objs=[obj], collection_name=collection_name, now=now
)
return self._db_response(result=result)


class DbEngine(_Engine):
def __init__(self, mongo_uri, db_name, tz_info: timezone = None):
Expand All @@ -99,9 +130,9 @@ def __init__(self, mongo_uri, db_name, tz_info: timezone = None):
tz_info=tz_info,
)

def save_all(self, objs: list[DbModel]):
def save_all(self, obj_list: list[DbModel]):
indexes, operations, now = self._create_operations_list(
objs=objs, query=None, raw_query=None
objs=obj_list, query=None, raw_query=None
)
for collection_name, index_list in indexes.items():
self._db[collection_name].create_indexes(index_list)
Expand All @@ -110,5 +141,19 @@ def save_all(self, objs: list[DbModel]):
operation_list
)
self._after_save(
result=result, objs=objs, collection_name=collection_name, now=now
result=result, objs=obj_list, collection_name=collection_name, now=now
)

def save(self, obj: DbModel, query: QueryOperator = None, raw_query: dict = None):
indexes, operations, now = self._create_operations_list(
objs=[obj], query=query, raw_query=raw_query
)
collection_name = obj._collection
index_list = indexes[collection_name]
self._db[collection_name].create_indexes(index_list)
operation_list = operations[collection_name]
result: BulkWriteResult = self._db[collection_name].bulk_write(operation_list)
self._after_save(
result=result, objs=[obj], collection_name=collection_name, now=now
)
return self._db_response(result=result)
10 changes: 10 additions & 0 deletions pyodmongo/models/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,16 @@
from .id_model import Id


class DbResponse(BaseModel):
acknowledged: bool
deleted_count: int
inserted_count: int
matched_count: int
modified_count: int
upserted_count: int
upserted_ids: dict[int, Id]


class SaveResponse(BaseModel):
"""
Represents the response from a save operation (insert or update) in PyODMongo. This
Expand Down
49 changes: 44 additions & 5 deletions tests/test_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
from typing import ClassVar
import pytest
import pytest_asyncio
from pyodmongo import AsyncDbEngine, DbEngine, DbModel, Field
from pyodmongo import AsyncDbEngine, DbEngine, DbModel, Field, DbResponse
from bson import ObjectId
import copy

mongo_uri = "mongodb://localhost:27017"
db_name = "pyodmongo_pytest"
Expand All @@ -14,7 +15,7 @@


@pytest_asyncio.fixture()
async def drop_collection():
async def drop_db():
await async_engine._client.drop_database(db_name)
engine._client.drop_database(db_name)
yield
Expand All @@ -23,9 +24,7 @@ async def drop_collection():


@pytest.mark.asyncio
async def test_save_all_upsert(drop_collection):
print()

async def test_save_all(drop_db):
class MyClass0(DbModel):
attr_0: str = Field(index=True)
attr_1: int = Field(index=True)
Expand All @@ -48,3 +47,43 @@ class MyClass1(DbModel):
assert ObjectId.is_valid(obj_1.id)
assert ObjectId.is_valid(obj_2.id)
assert ObjectId.is_valid(obj_3.id)

id_0 = copy.copy(obj_0.id)
id_1 = copy.copy(obj_1.id)
id_2 = copy.copy(obj_2.id)
id_3 = copy.copy(obj_3.id)

obj_0.attr_0 = "zero_zero"
obj_1.attr_0 = "one_one"
obj_2.attr_2 = "two_two"
obj_3.attr_2 = "three_three"

await async_engine.save_all([obj_1, obj_3])
engine.save_all([obj_0, obj_2])

assert obj_0.id == id_0
assert obj_1.id == id_1
assert obj_2.id == id_2
assert obj_3.id == id_3


@pytest.mark.asyncio
async def test_save(drop_db):
class MyClass0(DbModel):
attr_0: str = Field(index=True)
attr_1: int = Field(index=True)
_collection: ClassVar = "my_class_0"

class MyClass1(DbModel):
attr_2: str = Field(index=True)
attr_3: int = Field(index=True)
_collection: ClassVar = "my_class_1"

obj_0 = MyClass0(attr_0="zero", attr_1=0)
obj_1 = MyClass1(attr_2="two", attr_3=2)

response_0: DbResponse = await async_engine.save(obj_0)
response_1: DbResponse = engine.save(obj_1)

assert obj_0.id == response_0.upserted_ids[0]
assert obj_1.id == response_1.upserted_ids[0]

0 comments on commit f80d11a

Please sign in to comment.