Skip to content

Commit

Permalink
feat: use redis pubsub to sync s3 bucket changes across nodes (#150)
Browse files Browse the repository at this point in the history
* feat: use redis pubsub to sync s3 bucket changes across nodes

* replace destructor with a cleanup fn

---------

Co-authored-by: Avram Tudor <[email protected]>
  • Loading branch information
quitrk and Avram Tudor authored Feb 17, 2025
1 parent b25032f commit 79e523c
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 14 deletions.
2 changes: 1 addition & 1 deletion run-skynet-requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ async def get(job_id):
duration = response['duration']
total_duration += duration

print(f'Job {job_id} status: {status} duration: {duration} \n')
print(f'Job {job_id} status: {status} duration: {duration} \n {response["result"]} \n')

if status != 'success':
success = False
Expand Down
5 changes: 5 additions & 0 deletions skynet/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ async def lifespan(main_app: FastAPI):

await executor_shutdown()

if 'assistant' in modules:
from skynet.modules.ttt.assistant.app import app_shutdown as assistant_shutdown

await assistant_shutdown()

await http_client.close()


Expand Down
11 changes: 10 additions & 1 deletion skynet/modules/ttt/assistant/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,13 @@ async def app_startup():
log.info('assistant module initialized')


__all__ = ['app', 'app_startup']
async def app_shutdown():
await db.close()
log.info('Persistence closed')

vector_store = await get_vector_store()
await vector_store.cleanup()
log.info('vector store cleaned up')


__all__ = ['app', 'app_startup', 'app_shutdown']
2 changes: 2 additions & 0 deletions skynet/modules/ttt/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def format_docs(docs: list[Document]) -> str:
for doc in docs:
log.debug(doc.metadata.get('source'))

log.info(f'Using {len(docs)} documents for RAG')

return '\n\n'.join(doc.page_content for doc in docs)


Expand Down
14 changes: 10 additions & 4 deletions skynet/modules/ttt/rag/stores/faiss.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import shutil
import time

Expand Down Expand Up @@ -30,11 +29,18 @@ def __init__(self):
def get_vector_store_path(self, store_id):
return f'{vector_store_path}/faiss/{store_id}'

async def cleanup(self):
await super().cleanup()

if self.s3:
await self.s3.cleanup()

async def initialize(self):
await super().initialize()

if use_s3:
if self.s3:
await self.s3.replicate(self.get_vector_store_path)
await self.s3.listen()

async def get(self, store_id):
try:
Expand Down Expand Up @@ -62,7 +68,7 @@ async def create(self, store_id, documents):
end = time.perf_counter_ns()
duration = round((end - start) / 1e9)

if use_s3:
if self.s3:
await self.s3.upload(self.get_vector_store_path(store_id))

log.info(f'Saving vector store took {duration} seconds')
Expand All @@ -75,7 +81,7 @@ async def delete(self, store_id):
path = self.get_vector_store_path(store_id)
shutil.rmtree(path, ignore_errors=True)

if use_s3:
if self.s3:
await self.s3.delete(path)


Expand Down
28 changes: 25 additions & 3 deletions skynet/modules/ttt/rag/stores/s3.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
import json
import os

from skynet.env import vector_store_type
from skynet.env import app_uuid, vector_store_type
from skynet.logs import get_logger
from skynet.modules.ttt.persistence import db
from skynet.modules.ttt.rag.constants import STORED_RAG_KEY
Expand All @@ -21,6 +23,16 @@ async def files_aiter():
class RagS3:
def __init__(self):
self.s3 = S3()
self.listen_task = None

async def cleanup(self):
if self.listen_task:
self.listen_task.cancel()

async def listen(self):
pubsub = db.db.pubsub()
await pubsub.subscribe(**{'s3-upload': self.handleS3Upload})
self.listen_task = asyncio.create_task(pubsub.run())

async def replicate(self, prefix_function: callable):
stored_keys = await db.lrange(STORED_RAG_KEY, 0, -1)
Expand All @@ -33,12 +45,22 @@ async def replicate(self, prefix_function: callable):
await self.s3.download_file(f'{folder}/{filename}')

async def upload(self, folder):
async for filename in files_aiter():
await self.s3.upload_file(f'{folder}/{filename}')
async for name in files_aiter():
filename = f'{folder}/{name}'
await self.s3.upload_file(filename)
await db.db.publish('s3-upload', json.dumps({'filename': filename, 'app_uuid': app_uuid}))

async def delete(self, folder):
async for filename in files_aiter():
await self.s3.delete_file(f'{folder}/{filename}')

async def handleS3Upload(self, message):
message = json.loads(message['data'])
filename = message.get('filename')
uuid = message.get('app_uuid')

if uuid != app_uuid:
await self.s3.download_file(filename)


__all__ = ['RagS3']
12 changes: 10 additions & 2 deletions skynet/modules/ttt/rag/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ async def initialize(self):
"""
pass

@abstractmethod
async def cleanup(self):
"""
Clean up the vector store.
"""
pass

@abstractmethod
def get_vector_store_path(self, store_id: str):
"""
Expand Down Expand Up @@ -94,15 +101,16 @@ async def workflow(self, payload: RagPayload, store_id: str):
documents = await crawl(payload)

await self.create(store_id, documents)
await db.lrem(RUNNING_RAG_KEY, 0, store_id)
await db.lrem(STORED_RAG_KEY, 0, store_id) # ensure no duplicates
await db.rpush(STORED_RAG_KEY, store_id)
await self.update_config(store_id, status=RagStatus.SUCCESS)
except Exception as e:
await db.lrem(RUNNING_RAG_KEY, 0, store_id)
await db.rpush(ERROR_RAG_KEY, store_id)
await self.update_config(store_id, status=RagStatus.ERROR, error=str(e))
log.error(e)

await db.lrem(RUNNING_RAG_KEY, 0, store_id)

async def create_from_urls(self, payload: RagPayload, store_id: str) -> Optional[RagConfig]:
"""
Create a vector store with the given id, using the documents crawled from the given URL.
Expand Down
6 changes: 3 additions & 3 deletions skynet/modules/ttt/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ async def download_file(self, filename):
await obj.download_fileobj(data)
log.info(f'Downloaded file from S3: {filename}')
except Exception as e:
log.error(f'Failed to download file from S3: {e}')
log.error(f'Failed to download file {filename} from S3: {e}')

async def upload_file(self, filename):
try:
Expand All @@ -40,7 +40,7 @@ async def upload_file(self, filename):
await bucket.upload_fileobj(data, filename)
log.info(f'Uploaded file to S3: {filename}')
except Exception as e:
log.error(f'Failed to upload file to S3: {e}')
log.error(f'Failed to upload file {filename} to S3: {e}')

async def delete_file(self, filename):
try:
Expand All @@ -49,7 +49,7 @@ async def delete_file(self, filename):
await obj.delete()
log.info(f'Deleted file from S3: {filename}')
except Exception as e:
log.error(f'Failed to delete file from S3: {e}')
log.error(f'Failed to delete file {filename} from S3: {e}')


__all__ = ['S3']

0 comments on commit 79e523c

Please sign in to comment.