Skip to content

Commit

Permalink
♻️ save_all
Browse files Browse the repository at this point in the history
  • Loading branch information
mauro-andre committed May 24, 2024
1 parent 3d850fa commit 9d29588
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 2 deletions.
5 changes: 3 additions & 2 deletions pyodmongo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .async_engine.engine import AsyncDbEngine
from .engine.engine import DbEngine
# from .async_engine.engine import AsyncDbEngine
# from .engine.engine import DbEngine
from .engines.engines import AsyncDbEngine, DbEngine
from .models.db_model import DbModel, MainBaseModel
from .models.id_model import Id
from .models.paginate import ResponsePaginate
Expand Down
114 changes: 114 additions & 0 deletions pyodmongo/engines/engines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import MongoClient, UpdateMany
from pymongo.results import BulkWriteResult
from datetime import datetime, timezone, UTC
from bson import ObjectId
from ..models.db_model import DbModel
from ..models.query_operators import QueryOperator
from ..engine.utils import consolidate_dict, mount_base_pipeline


class _Engine:
    """Shared plumbing for the sync and async DB engines.

    Builds the Mongo client/database handles and turns lists of ``DbModel``
    instances into ``UpdateMany`` upsert operations grouped by collection.
    """

    def __init__(self, Client, mongo_uri, db_name, tz_info: timezone = None):
        """Create the underlying client and select the working database.

        Args:
            Client: Client factory (``MongoClient`` or ``AsyncIOMotorClient``).
            mongo_uri: MongoDB connection string.
            db_name: Name of the database to operate on.
            tz_info: Timezone used when stamping ``created_at``/``updated_at``;
                ``None`` produces naive local datetimes.
        """
        self._client = Client(mongo_uri)
        self._db = self._client[db_name]
        self._tz_info = tz_info

    def _update_many_operation(self, obj: DbModel, query_dict: dict, now):
        """Build one upsert ``UpdateMany`` operation for *obj*.

        When *query_dict* is falsy, the object's own ``_id`` is used as the
        filter, so a new object is inserted and an existing one is updated.
        ``created_at`` is only written on insert via ``$setOnInsert``.
        """
        dct = consolidate_dict(obj=obj, dct={})
        find_filter = query_dict or {"_id": ObjectId(dct.get("_id"))}
        dct[obj.__class__.updated_at.field_alias] = now
        # _id and created_at must not appear in $set: _id is immutable and
        # created_at is handled exclusively by $setOnInsert below.
        dct.pop("_id")
        dct.pop(obj.__class__.created_at.field_alias)
        to_save = {
            "$set": dct,
            "$setOnInsert": {obj.__class__.created_at.field_alias: now},
        }
        return UpdateMany(filter=find_filter, update=to_save, upsert=True)

    def _create_operations_list(
        self,
        objs: list[DbModel],
        query: QueryOperator,
        raw_query: dict,
    ):
        """Group upsert operations and index definitions by collection.

        Returns:
            Tuple ``(indexes, operations, now)``: *indexes* maps collection
            name to its index list, *operations* maps it to the list of
            ``UpdateMany`` operations, and *now* is the shared timestamp.
        """
        operations: dict = {}
        indexes: dict = {}
        query = query.to_dict() if query else (raw_query or {})
        now = datetime.now(self._tz_info)
        # Truncate to millisecond precision to match BSON datetime resolution.
        now = now.replace(microsecond=int(now.microsecond / 1000) * 1000)
        for obj in objs:
            operation = self._update_many_operation(obj=obj, query_dict=query, now=now)
            collection_name = obj._collection
            # setdefault is the idiomatic (non-exception-driven) way to group
            # items by key — replaces the previous try/except KeyError.
            operations.setdefault(collection_name, []).append(operation)

            # NOTE(review): some models apparently expose only _init_indexes
            # instead of _indexes — fall back to it; confirm against DbModel.
            try:
                obj_indexes = obj._indexes
            except AttributeError:
                obj_indexes = obj._init_indexes
            indexes[collection_name] = obj_indexes
        return indexes, operations, now

    def _after_save(
        self, result: BulkWriteResult, objs: list[DbModel], collection_name: str, now
    ):
        """Write back ids and timestamps onto the objects that were inserted.

        ``result.upserted_ids`` maps each operation's position in the bulk
        request to the new ``_id``; operations were built in the same order
        as the objects of this collection, so positions line up.
        """
        objs_from_collection = list(
            filter(lambda x: x._collection == collection_name, objs)
        )
        for index, obj_id in result.upserted_ids.items():
            objs_from_collection[index].id = obj_id
            objs_from_collection[index].created_at = now
            objs_from_collection[index].updated_at = now


class AsyncDbEngine(_Engine):
    """Asynchronous engine backed by motor's ``AsyncIOMotorClient``."""

    def __init__(self, mongo_uri, db_name, tz_info: timezone = None):
        """Connect to MongoDB with the async (motor) client."""
        super().__init__(
            Client=AsyncIOMotorClient,
            mongo_uri=mongo_uri,
            db_name=db_name,
            tz_info=tz_info,
        )

    async def save_all(self, objs: list[DbModel]):
        """Upsert every object in *objs*, creating their indexes first."""
        indexes, operations, now = self._create_operations_list(
            objs=objs, query=None, raw_query=None
        )
        # Ensure indexes exist before writing any documents.
        for name, index_list in indexes.items():
            await self._db[name].create_indexes(index_list)
        # One bulk write per collection, then push ids/timestamps back.
        for name, op_list in operations.items():
            outcome: BulkWriteResult = await self._db[name].bulk_write(op_list)
            self._after_save(result=outcome, objs=objs, collection_name=name, now=now)


class DbEngine(_Engine):
    """Synchronous engine backed by pymongo's ``MongoClient``."""

    def __init__(self, mongo_uri, db_name, tz_info: timezone = None):
        """Connect to MongoDB with the blocking (pymongo) client."""
        super().__init__(
            Client=MongoClient,
            mongo_uri=mongo_uri,
            db_name=db_name,
            tz_info=tz_info,
        )

    def save_all(self, objs: list[DbModel]):
        """Upsert every object in *objs*, creating their indexes first."""
        indexes, operations, now = self._create_operations_list(
            objs=objs, query=None, raw_query=None
        )
        # Ensure indexes exist before writing any documents.
        for name, index_list in indexes.items():
            self._db[name].create_indexes(index_list)
        # One bulk write per collection, then push ids/timestamps back.
        for name, op_list in operations.items():
            outcome: BulkWriteResult = self._db[name].bulk_write(op_list)
            self._after_save(result=outcome, objs=objs, collection_name=name, now=now)
50 changes: 50 additions & 0 deletions tests/test_engines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from datetime import timezone, timedelta
from typing import ClassVar
import pytest
import pytest_asyncio
from pyodmongo import AsyncDbEngine, DbEngine, DbModel, Field
from bson import ObjectId

mongo_uri = "mongodb://localhost:27017"
db_name = "pyodmongo_pytest"
tz_info = timezone(timedelta(hours=-3))

async_engine = AsyncDbEngine(mongo_uri=mongo_uri, db_name=db_name, tz_info=tz_info)
engine = DbEngine(mongo_uri=mongo_uri, db_name=db_name, tz_info=tz_info)


@pytest_asyncio.fixture()
async def drop_collection():
    # Start each test from an empty database; drop via both clients so the
    # async and sync engines see the same clean state.
    await async_engine._client.drop_database(db_name)
    engine._client.drop_database(db_name)
    yield
    # NOTE(review): the teardown drop is commented out — presumably so the
    # database can be inspected after a run; confirm before re-enabling.
    # await async_engine._client.drop_database(db_name)
    # engine._client.drop_database(db_name)


@pytest.mark.asyncio
async def test_save_all_upsert(drop_collection):
    """save_all must upsert objects on both engines and write back valid ids."""

    class MyClass0(DbModel):
        attr_0: str = Field(index=True)
        attr_1: int = Field(index=True)
        _collection: ClassVar = "my_class_0"

    class MyClass1(DbModel):
        attr_2: str = Field(index=True)
        attr_3: int = Field(index=True)
        _collection: ClassVar = "my_class_1"

    obj_0 = MyClass0(attr_0="zero", attr_1=0)
    obj_1 = MyClass0(attr_0="one", attr_1=1)
    obj_2 = MyClass1(attr_2="two", attr_3=2)
    obj_3 = MyClass1(attr_2="three", attr_3=3)

    # Exercise the async and the sync engine, each across both collections.
    await async_engine.save_all([obj_0, obj_2])
    engine.save_all([obj_1, obj_3])

    # save_all must have written a real ObjectId back onto every object.
    assert ObjectId.is_valid(obj_0.id)
    assert ObjectId.is_valid(obj_1.id)
    assert ObjectId.is_valid(obj_2.id)
    assert ObjectId.is_valid(obj_3.id)

0 comments on commit 9d29588

Please sign in to comment.