Skip to content

Commit

Permalink
exec store cleanup and fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ritch committed Oct 2, 2024
1 parent d4d9683 commit 9f9182b
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 43 deletions.
50 changes: 40 additions & 10 deletions fiftyone/factory/repos/execution_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Execution store repository.
"""

import datetime
from pymongo.collection import Collection
from fiftyone.operators.store.models import StoreDocument, KeyDocument

Expand Down Expand Up @@ -35,18 +36,36 @@ def list_stores(self) -> list[str]:
return self._collection.distinct("store_name")

def set_key(self, store_name, key, value, ttl=None) -> KeyDocument:
"""Sets or updates a key in the specified store."""
"""Sets or updates a key in the specified store"""
now = datetime.datetime.now()
expiration = KeyDocument.get_expiration(ttl)
key_doc = KeyDocument(
store_name=store_name,
key=key,
value=value,
expires_at=expiration if ttl else None,
store_name=store_name, key=key, value=value, updated_at=now
)
# Update or insert the key
self._collection.update_one(
_where(store_name, key), {"$set": key_doc.dict()}, upsert=True

# Prepare the update operations
update_fields = {
"$set": key_doc.dict(
exclude={"created_at", "expires_at", "store_name", "key"}
),
"$setOnInsert": {
"store_name": store_name,
"key": key,
"created_at": now,
"expires_at": expiration if ttl else None,
},
}

# Perform the upsert operation
result = self._collection.update_one(
_where(store_name, key), update_fields, upsert=True
)

if result.upserted_id:
key_doc.created_at = now
else:
key_doc.updated_at = now

return key_doc

def get_key(self, store_name, key) -> KeyDocument:
Expand All @@ -57,8 +76,7 @@ def get_key(self, store_name, key) -> KeyDocument:

def list_keys(self, store_name) -> list[str]:
"""Lists all keys in the specified store."""
keys = self._collection.find(_where(store_name))
# TODO: redact non-key fields
keys = self._collection.find(_where(store_name), {"key": 1})
return [key["key"] for key in keys]

def update_ttl(self, store_name, key, ttl) -> bool:
Expand Down Expand Up @@ -92,7 +110,19 @@ def __init__(self, collection: Collection):
def _create_indexes(self):
indices = self._collection.list_indexes()
expires_at_name = "expires_at"
store_name_name = "store_name"
key_name = "key"
full_key_name = "store_name_and_key"
if expires_at_name not in indices:
self._collection.create_index(
expires_at_name, name=expires_at_name, expireAfterSeconds=0
)
if full_key_name not in indices:
self._collection.create_index(
[(store_name_name, 1), (key_name, 1)],
name=full_key_name,
unique=True,
)
for name in [store_name_name, key_name]:
if name not in indices:
self._collection.create_index(name, name=name)
47 changes: 14 additions & 33 deletions tests/unittests/execution_store_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,15 @@ def test_set_key(self):
{"store_name": "widgets", "key": "widget_1"},
{
"$set": {
"value": {"name": "Widget One", "value": 100},
"updated_at": IsDateTime(),
},
"$setOnInsert": {
"store_name": "widgets",
"key": "widget_1",
"value": {"name": "Widget One", "value": 100},
"created_at": IsDateTime(),
"updated_at": None,
"expires_at": IsDateTime(),
}
},
},
upsert=True,
)
Expand Down Expand Up @@ -140,7 +142,9 @@ def test_list_keys(self):
keys = self.store_repo.list_keys("widgets")
assert keys == ["widget_1", "widget_2"]
self.mock_collection.find.assert_called_once()
self.mock_collection.find.assert_called_with({"store_name": "widgets"})
self.mock_collection.find.assert_called_with(
{"store_name": "widgets"}, {"key": 1}
)

def test_list_stores(self):
self.mock_collection.distinct.return_value = ["widgets", "gadgets"]
Expand All @@ -166,42 +170,19 @@ def test_set(self):
{"store_name": "mock_store", "key": "widget_1"},
{
"$set": {
"updated_at": IsDateTime(),
"value": {"name": "Widget One", "value": 100},
},
"$setOnInsert": {
"store_name": "mock_store",
"key": "widget_1",
"value": {"name": "Widget One", "value": 100},
"created_at": IsDateTime(),
"updated_at": None,
"expires_at": IsDateTime(),
}
},
},
upsert=True,
)

# def test_update(self):
# self.mock_collection.find_one.return_value = {
# "store_name": "mock_store",
# "key": "widget_1",
# "value": {"name": "Widget One", "value": 100},
# "created_at": time.time(),
# "updated_at": time.time(),
# "expires_at": time.time() + 60000
# }
# self.store.update_key("widget_1", {"name": "Widget One", "value": 200})
# self.mock_collection.update_one.assert_called_once()
# self.mock_collection.update_one.assert_called_with(
# {"store_name": "mock_store", "key": "widget_1"},
# {
# "$set": {
# "store_name": "mock_store",
# "key": "widget_1",
# "value": {"name": "Widget One", "value": 200},
# "created_at": IsDateTime(),
# "updated_at": IsDateTime(),
# "expires_at": IsDateTime()
# }
# }
# )

def test_get(self):
self.mock_collection.find_one.return_value = {
"store_name": "mock_store",
Expand All @@ -227,7 +208,7 @@ def test_list_keys(self):
assert keys == ["widget_1", "widget_2"]
self.mock_collection.find.assert_called_once()
self.mock_collection.find.assert_called_with(
{"store_name": "mock_store"}
{"store_name": "mock_store"}, {"key": 1}
)

def test_delete(self):
Expand Down

0 comments on commit 9f9182b

Please sign in to comment.