From 79e523c3950faebd7a9ff09a0e6e1488b670608b Mon Sep 17 00:00:00 2001
From: Avram Tudor
Date: Mon, 17 Feb 2025 13:54:31 +0200
Subject: [PATCH] feat: use redis pubsub to sync s3 bucket changes across nodes (#150)

* feat: use redis pubsub to sync s3 bucket changes across nodes

* replace destructor with a cleanup fn

---------

Co-authored-by: Avram Tudor
---
 run-skynet-requests.py                 |  2 +-
 skynet/main.py                         |  5 +++++
 skynet/modules/ttt/assistant/app.py    | 11 +++++++++-
 skynet/modules/ttt/processor.py        |  2 ++
 skynet/modules/ttt/rag/stores/faiss.py | 14 +++++++++----
 skynet/modules/ttt/rag/stores/s3.py    | 28 +++++++++++++++++++++++---
 skynet/modules/ttt/rag/vector_store.py | 12 +++++++++--
 skynet/modules/ttt/s3.py               |  6 +++---
 8 files changed, 66 insertions(+), 14 deletions(-)

diff --git a/run-skynet-requests.py b/run-skynet-requests.py
index a4e11756..7ca8779e 100644
--- a/run-skynet-requests.py
+++ b/run-skynet-requests.py
@@ -83,7 +83,7 @@ async def get(job_id):
     duration = response['duration']
     total_duration += duration
 
-    print(f'Job {job_id} status: {status} duration: {duration} \n')
+    print(f'Job {job_id} status: {status} duration: {duration} \n {response["result"]} \n')
 
     if status != 'success':
         success = False
diff --git a/skynet/main.py b/skynet/main.py
index 8656cc60..9588976d 100644
--- a/skynet/main.py
+++ b/skynet/main.py
@@ -68,6 +68,11 @@ async def lifespan(main_app: FastAPI):
 
         await executor_shutdown()
 
+    if 'assistant' in modules:
+        from skynet.modules.ttt.assistant.app import app_shutdown as assistant_shutdown
+
+        await assistant_shutdown()
+
     await http_client.close()
diff --git a/skynet/modules/ttt/assistant/app.py b/skynet/modules/ttt/assistant/app.py
index f62f7861..7b906229 100644
--- a/skynet/modules/ttt/assistant/app.py
+++ b/skynet/modules/ttt/assistant/app.py
@@ -25,4 +25,13 @@ async def app_startup():
     log.info('assistant module initialized')
 
 
-__all__ = ['app', 'app_startup']
+async def app_shutdown():
+    await db.close()
+    log.info('Persistence closed')
+
+    vector_store = await get_vector_store()
+    await vector_store.cleanup()
+    log.info('vector store cleaned up')
+
+
+__all__ = ['app', 'app_startup', 'app_shutdown']
diff --git a/skynet/modules/ttt/processor.py b/skynet/modules/ttt/processor.py
index e899523f..797f4eb0 100644
--- a/skynet/modules/ttt/processor.py
+++ b/skynet/modules/ttt/processor.py
@@ -65,6 +65,8 @@ def format_docs(docs: list[Document]) -> str:
     for doc in docs:
         log.debug(doc.metadata.get('source'))
 
+    log.info(f'Using {len(docs)} documents for RAG')
+
     return '\n\n'.join(doc.page_content for doc in docs)
diff --git a/skynet/modules/ttt/rag/stores/faiss.py b/skynet/modules/ttt/rag/stores/faiss.py
index e1fd0910..325e4603 100644
--- a/skynet/modules/ttt/rag/stores/faiss.py
+++ b/skynet/modules/ttt/rag/stores/faiss.py
@@ -1,4 +1,3 @@
-import asyncio
 import shutil
 import time
 
@@ -30,11 +29,18 @@ def __init__(self):
     def get_vector_store_path(self, store_id):
         return f'{vector_store_path}/faiss/{store_id}'
 
+    async def cleanup(self):
+        await super().cleanup()
+
+        if self.s3:
+            await self.s3.cleanup()
+
     async def initialize(self):
         await super().initialize()
 
-        if use_s3:
+        if self.s3:
             await self.s3.replicate(self.get_vector_store_path)
+            await self.s3.listen()
 
     async def get(self, store_id):
         try:
@@ -62,7 +68,7 @@ async def create(self, store_id, documents):
         end = time.perf_counter_ns()
         duration = round((end - start) / 1e9)
 
-        if use_s3:
+        if self.s3:
             await self.s3.upload(self.get_vector_store_path(store_id))
 
         log.info(f'Saving vector store took {duration} seconds')
@@ -75,7 +81,7 @@ async def delete(self, store_id):
         path = self.get_vector_store_path(store_id)
         shutil.rmtree(path, ignore_errors=True)
 
-        if use_s3:
+        if self.s3:
             await self.s3.delete(path)
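Note (editor's sketch, not part of the patch): the commit message says a destructor was replaced with a cleanup fn, and the hunks above wire that up: `FAISS.cleanup()` delegates to `RagS3.cleanup()`, which is invoked from `app_shutdown()` during the FastAPI lifespan teardown. A minimal, self-contained sketch of that cancel-on-shutdown pattern, with illustrative names only:

```python
import asyncio


class BackgroundListener:
    """Sketch: hold a handle to a background task and cancel it from an
    explicit async cleanup hook, since __del__ cannot await anything."""

    def __init__(self):
        self.listen_task: asyncio.Task | None = None

    async def listen(self):
        # Schedule the long-running consumer without blocking startup.
        self.listen_task = asyncio.create_task(self._run())

    async def _run(self):
        while True:  # stand-in for pubsub.run()
            await asyncio.sleep(1)

    async def cleanup(self):
        # Called from the app's shutdown path, mirroring app_shutdown().
        if self.listen_task:
            self.listen_task.cancel()
            try:
                await self.listen_task
            except asyncio.CancelledError:
                pass  # expected once the task is cancelled


async def main():
    listener = BackgroundListener()
    await listener.listen()
    await listener.cleanup()


asyncio.run(main())
```

The patch itself only cancels the task without awaiting it; awaiting the cancelled task, as in the sketch, additionally surfaces any unexpected errors raised during teardown.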
diff --git a/skynet/modules/ttt/rag/stores/s3.py b/skynet/modules/ttt/rag/stores/s3.py
index 17576d05..621ff4ad 100644
--- a/skynet/modules/ttt/rag/stores/s3.py
+++ b/skynet/modules/ttt/rag/stores/s3.py
@@ -1,6 +1,8 @@
+import asyncio
+import json
 import os
 
-from skynet.env import vector_store_type
+from skynet.env import app_uuid, vector_store_type
 from skynet.logs import get_logger
 from skynet.modules.ttt.persistence import db
 from skynet.modules.ttt.rag.constants import STORED_RAG_KEY
@@ -21,6 +23,16 @@ async def files_aiter():
 class RagS3:
     def __init__(self):
         self.s3 = S3()
+        self.listen_task = None
+
+    async def cleanup(self):
+        if self.listen_task:
+            self.listen_task.cancel()
+
+    async def listen(self):
+        pubsub = db.db.pubsub()
+        await pubsub.subscribe(**{'s3-upload': self.handleS3Upload})
+        self.listen_task = asyncio.create_task(pubsub.run())
 
     async def replicate(self, prefix_function: callable):
         stored_keys = await db.lrange(STORED_RAG_KEY, 0, -1)
@@ -33,12 +45,22 @@ async def replicate(self, prefix_function: callable):
             await self.s3.download_file(f'{folder}/{filename}')
 
     async def upload(self, folder):
-        async for filename in files_aiter():
-            await self.s3.upload_file(f'{folder}/{filename}')
+        async for name in files_aiter():
+            filename = f'{folder}/{name}'
+            await self.s3.upload_file(filename)
+            await db.db.publish('s3-upload', json.dumps({'filename': filename, 'app_uuid': app_uuid}))
 
     async def delete(self, folder):
         async for filename in files_aiter():
             await self.s3.delete_file(f'{folder}/{filename}')
 
+    async def handleS3Upload(self, message):
+        message = json.loads(message['data'])
+        filename = message.get('filename')
+        uuid = message.get('app_uuid')
+
+        if uuid != app_uuid:
+            await self.s3.download_file(filename)
+
 
 __all__ = ['RagS3']
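Note (editor's sketch, not part of the patch): the file above is the heart of the change. Every node subscribes to the `s3-upload` channel through the raw redis client (`db.db`); whichever node saves a vector store publishes each uploaded file name tagged with its own `app_uuid`, and subscribers pull the file from S3 only when the announcement came from a different node, so the uploader never re-downloads its own work. A standalone sketch of that fan-out with redis-py's asyncio client (channel name and payload shape are from the patch; it assumes a Redis server on localhost, and the simulated peer uuid is illustrative):

```python
import asyncio
import json
import uuid

import redis.asyncio as redis  # redis-py's asyncio client

APP_UUID = str(uuid.uuid4())  # per-process identity, analogous to skynet's app_uuid


async def handle_s3_upload(message):
    data = json.loads(message['data'])
    # Skip our own announcements; only peers need to fetch the new file.
    if data.get('app_uuid') != APP_UUID:
        print(f"peer uploaded {data.get('filename')}, downloading it from S3...")


async def main():
    client = redis.Redis()
    pubsub = client.pubsub()

    # The **{...} form is needed because 's3-upload' is not a valid Python
    # identifier; registering a handler makes pubsub.run() dispatch to it.
    await pubsub.subscribe(**{'s3-upload': handle_s3_upload})
    listen_task = asyncio.create_task(pubsub.run())

    # Publisher side: announce an upload. A fake peer uuid is used here so
    # the handler actually fires; a message tagged with our own uuid is ignored.
    await client.publish(
        's3-upload',
        json.dumps({'filename': 'faiss/some-store/index.faiss', 'app_uuid': 'another-node'}),
    )

    await asyncio.sleep(0.5)  # give the listener a moment to dispatch
    listen_task.cancel()


asyncio.run(main())
```

One caveat worth keeping in mind: Redis pub/sub delivery is fire-and-forget, so a node that is down or restarting when the message is published misses it. The `replicate()` pass at startup is what backfills those gaps from the STORED_RAG_KEY list.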
diff --git a/skynet/modules/ttt/rag/vector_store.py b/skynet/modules/ttt/rag/vector_store.py
index 4f1c7efc..12fbb17a 100644
--- a/skynet/modules/ttt/rag/vector_store.py
+++ b/skynet/modules/ttt/rag/vector_store.py
@@ -30,6 +30,13 @@ async def initialize(self):
         """
         pass
 
+    @abstractmethod
+    async def cleanup(self):
+        """
+        Clean up the vector store.
+        """
+        pass
+
     @abstractmethod
     def get_vector_store_path(self, store_id: str):
         """
@@ -94,15 +101,16 @@ async def workflow(self, payload: RagPayload, store_id: str):
             documents = await crawl(payload)
             await self.create(store_id, documents)
 
-            await db.lrem(RUNNING_RAG_KEY, 0, store_id)
+            await db.lrem(STORED_RAG_KEY, 0, store_id)  # ensure no duplicates
             await db.rpush(STORED_RAG_KEY, store_id)
             await self.update_config(store_id, status=RagStatus.SUCCESS)
         except Exception as e:
-            await db.lrem(RUNNING_RAG_KEY, 0, store_id)
             await db.rpush(ERROR_RAG_KEY, store_id)
             await self.update_config(store_id, status=RagStatus.ERROR, error=str(e))
             log.error(e)
 
+        await db.lrem(RUNNING_RAG_KEY, 0, store_id)
+
     async def create_from_urls(self, payload: RagPayload, store_id: str) -> Optional[RagConfig]:
         """
         Create a vector store with the given id, using the documents crawled from the given URL.
diff --git a/skynet/modules/ttt/s3.py b/skynet/modules/ttt/s3.py
index 7f0ec5c0..767fe220 100644
--- a/skynet/modules/ttt/s3.py
+++ b/skynet/modules/ttt/s3.py
@@ -29,7 +29,7 @@ async def download_file(self, filename):
             await obj.download_fileobj(data)
             log.info(f'Downloaded file from S3: {filename}')
         except Exception as e:
-            log.error(f'Failed to download file from S3: {e}')
+            log.error(f'Failed to download file {filename} from S3: {e}')
 
     async def upload_file(self, filename):
         try:
@@ -40,7 +40,7 @@ async def upload_file(self, filename):
             await bucket.upload_fileobj(data, filename)
             log.info(f'Uploaded file to S3: {filename}')
         except Exception as e:
-            log.error(f'Failed to upload file to S3: {e}')
+            log.error(f'Failed to upload file {filename} to S3: {e}')
 
     async def delete_file(self, filename):
         try:
@@ -49,7 +49,7 @@ async def delete_file(self, filename):
             await obj.delete()
             log.info(f'Deleted file from S3: {filename}')
         except Exception as e:
-            log.error(f'Failed to delete file from S3: {e}')
+            log.error(f'Failed to delete file {filename} from S3: {e}')
 
 
 __all__ = ['S3']
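Note (editor's sketch, not part of the patch): the `skynet/modules/ttt/s3.py` hunks only show the tail of each method, and the change itself is small: every failure log now names the file involved, which matters once several nodes replicate many stores concurrently. For context, here is a sketch of a download helper in the same shape, assuming an aioboto3-style async resource (the library, bucket name, and setup code are assumptions not shown in the diff; only the awaited `download_fileobj` call and the log lines come from the hunks):

```python
import io
import logging

import aioboto3  # assumed async S3 library; the diff does not show the imports

log = logging.getLogger(__name__)


class S3:
    """Sketch of a download helper whose failure log names the offending key."""

    def __init__(self, bucket_name='my-rag-bucket'):  # illustrative bucket name
        self.session = aioboto3.Session()
        self.bucket_name = bucket_name

    async def download_file(self, filename):
        try:
            async with self.session.resource('s3') as s3:
                obj = await s3.Object(self.bucket_name, filename)
                data = io.BytesIO()
                await obj.download_fileobj(data)
                log.info(f'Downloaded file from S3: {filename}')
                return data
        except Exception as e:
            # 'Which file failed?' is the first debugging question on a
            # multi-node deployment, hence {filename} in the message.
            log.error(f'Failed to download file {filename} from S3: {e}')
```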