diff --git a/context_chat_backend/chain/ingest/injest.py b/context_chat_backend/chain/ingest/injest.py index afab4f7..7e66853 100644 --- a/context_chat_backend/chain/ingest/injest.py +++ b/context_chat_backend/chain/ingest/injest.py @@ -78,6 +78,7 @@ def _sources_to_documents(sources: list[UploadFile]) -> list[Document]: 'title': source.headers.get('title'), 'type': source.headers.get('type'), 'modified': source.headers.get('modified'), + 'provider': source.headers.get('provider'), } document = Document(page_content=content, metadata=metadata) @@ -158,7 +159,7 @@ def embed_sources( # either not a file or a file that is allowed sources_filtered = [ source for source in sources - if not source.filename.startswith('file: ') + if (source.filename is not None and not source.filename.startswith('file: ')) or _allowed_file(source) ] diff --git a/context_chat_backend/controller.py b/context_chat_backend/controller.py index 5079d6d..63df74c 100644 --- a/context_chat_backend/controller.py +++ b/context_chat_backend/controller.py @@ -57,20 +57,21 @@ def _(userId: str): # TODO: for testing, remove later @app.get('/search') @enabled_guard(app) -def _(userId: str, keyword: str): - from chromadb import ClientAPI - from .vectordb import COLLECTION_NAME +def _(userId: str, sourceNames: str): + sourceNames: list[str] = [source.strip() for source in sourceNames.split(',') if source.strip() != ''] + + if len(sourceNames) == 0: + return JSONResponse('No sources provided', 400) db: BaseVectorDB = app.extra.get('VECTOR_DB') - client: ClientAPI = db.client - db.setup_schema(userId) - return JSONResponse( - client.get_collection(COLLECTION_NAME(userId)).get( - where_document={'$contains': [{'source': keyword}]}, - include=['metadatas'], - ) - ) + if db is None: + return JSONResponse('Error: VectorDB not initialised', 500) + + source_objs = db.get_objects_from_metadata(userId, 'source', sourceNames) + sources = list(map(lambda s: s.get('id'), source_objs.values())) + + return JSONResponse({ 'sources': sources }) @app.put('/enabled') @@ -110,19 +111,7 @@ def _(userId: Annotated[str, Body()], sourceNames: Annotated[list[str], Body()]) if db is None: return JSONResponse('Error: VectorDB not initialised', 500) - source_objs = db.get_objects_from_metadata(userId, 'source', sourceNames) - res = db.delete_by_ids(userId, [ - source.get('id') - for source in source_objs.values() - if value_of(source.get('id') is not None) - ]) - - # NOTE: None returned in `delete_by_ids` should have meant an error but it didn't in the case of - # weaviate maybe because of the way weaviate wrapper is implemented (langchain's api does not take - # class name as input, which will be required in future versions of weaviate) - if res is None: - print('Deletion query returned "None". This can happen in Weaviate even if the deletion was \ -successful, therefore not considered an error for now.') + res = db.delete(userId, 'source', sourceNames) if res is False: return JSONResponse('Error: VectorDB delete failed, check vectordb logs for more info.', 400) @@ -130,27 +119,18 @@ def _(userId: Annotated[str, Body()], sourceNames: Annotated[list[str], Body()]) return JSONResponse('All valid sources deleted') -@app.post('/deleteMatchingSources') +@app.post('/deleteSourcesByProvider') @enabled_guard(app) -def _(userId: Annotated[str, Body()], keyword: Annotated[str, Body()]): +def _(userId: Annotated[str, Body()], providerKey: Annotated[str, Body()]): + if value_of(providerKey) is None: + return JSONResponse('Invalid provider key provided', 400) + db: BaseVectorDB = app.extra.get('VECTOR_DB') if db is None: return JSONResponse('Error: VectorDB not initialised', 500) - objs = db.get_objects_from_metadata(userId, 'source', [keyword], True) - res = db.delete_by_ids(userId, [ - obj.get('id') - for obj in objs.values() - if value_of(obj.get('id') is not None) - ]) - - # NOTE: None returned in `delete_by_ids` should have meant an error but it didn't in the case of - # weaviate maybe because of the way weaviate wrapper is implemented (langchain's api does not take - # class name as input, which will be required in future versions of weaviate) - if res is None: - print('Deletion query returned "None". This can happen in Weaviate even if the deletion was \ -successful, therefore not considered an error for now.') + res = db.delete(userId, 'provider', [providerKey]) if res is False: return JSONResponse('Error: VectorDB delete failed, check vectordb logs for more info.', 400) @@ -169,6 +149,7 @@ def _(sources: list[UploadFile]): value_of(source.headers.get('userId')) and value_of(source.headers.get('type')) and value_of(source.headers.get('modified')) + and value_of(source.headers.get('provider')) for source in sources] ): return JSONResponse('Invaild/missing headers', 400) diff --git a/context_chat_backend/vectordb/base.py b/context_chat_backend/vectordb/base.py index dcf70c2..7e16b0e 100644 --- a/context_chat_backend/vectordb/base.py +++ b/context_chat_backend/vectordb/base.py @@ -4,6 +4,8 @@ from langchain.schema.embeddings import Embeddings from langchain.vectorstores import VectorStore +from ..utils import value_of + class BaseVectorDB(ABC): client = None @@ -56,7 +58,6 @@ def get_objects_from_metadata( user_id: str, metadata_key: str, values: List[str], - contains: bool = False, ) -> dict: ''' Get all objects with the given metadata key and values. @@ -70,9 +71,6 @@ def get_objects_from_metadata( Metadata key to get. values: List[str] List of metadata names to get. - contains: bool - If True, gets all objects that contain any of the given values, - otherwise gets all objects that have the given values. Returns ------- @@ -89,7 +87,7 @@ def get_objects_from_metadata( } ''' - def delete_by_ids(self, user_id: str, ids: list[str]) -> Optional[bool]: + def delete_by_ids(self, user_id: str, ids: list[str]) -> bool: ''' Deletes all documents with the given ids for the given user. @@ -102,9 +100,9 @@ def delete_by_ids(self, user_id: str, ids: list[str]) -> Optional[bool]: Returns ------- - Optional[bool] - Optional[bool]: True if deletion is successful, - False otherwise, None if not implemented. + bool + True if deletion is successful, + False otherwise ''' if len(ids) == 0: return True @@ -113,4 +111,49 @@ def delete_by_ids(self, user_id: str, ids: list[str]) -> Optional[bool]: if user_client is None: return False - return user_client.delete(ids) + res = user_client.delete(ids) + + # NOTE: None should have meant an error but it didn't in the case of + # weaviate maybe because of the way weaviate wrapper is implemented (langchain's api does not take + # class name as input, which will be required in future versions of weaviate) + if res is None: + print('Deletion query returned "None". This can happen in Weaviate even if the deletion was \ +successful, therefore not considered an error for now.') + return True + + return res + + def delete(self, user_id: str, metadata_key: str, values: list[str]) -> bool: + ''' + Deletes all documents with the matching values for the given metadata key. + + Args + ---- + user_id: str + User ID from whose database to delete the documents. + metadata_key: str + Metadata key to delete by. + values: list[str] + List of metadata values to match. + + Returns + ------- + bool + True if deletion is successful, + False otherwise + ''' + if len(values) == 0: + return True + + user_client = self.get_user_client(user_id) + if user_client is None: + return False + + objs = self.get_objects_from_metadata(user_id, metadata_key, values) + ids = [ + obj.get('id') + for obj in objs.values() + if value_of(obj.get('id') is not None) + ] + + return self.delete_by_ids(user_id, ids) diff --git a/context_chat_backend/vectordb/chroma.py b/context_chat_backend/vectordb/chroma.py index 0b00441..3c5e8e8 100644 --- a/context_chat_backend/vectordb/chroma.py +++ b/context_chat_backend/vectordb/chroma.py @@ -65,7 +65,6 @@ def get_objects_from_metadata( user_id: str, metadata_key: str, values: List[str], - contains: bool = False, ) -> dict: # NOTE: the limit of objects returned is not known, maybe it would be better to set one manually @@ -77,16 +76,7 @@ def get_objects_from_metadata( if len(values) == 0: return {} - if len(values) == 1: - if contains: - data_filter = { metadata_key: { '$in': values[0] } } - else: - data_filter = { metadata_key: values[0] } - else: - if contains: - data_filter = {'$or': [{ metadata_key: { '$in': val } } for val in values]} - else: - data_filter = {'$or': [{ metadata_key: val } for val in values]} + data_filter = { metadata_key: { '$in': values } } try: results = self.client.get_collection(COLLECTION_NAME(user_id)).get( diff --git a/context_chat_backend/vectordb/weaviate.py b/context_chat_backend/vectordb/weaviate.py index cc890a8..cf78a0f 100644 --- a/context_chat_backend/vectordb/weaviate.py +++ b/context_chat_backend/vectordb/weaviate.py @@ -51,6 +51,11 @@ 'description': 'Last modified time of the file', 'name': 'modified', }, + { + 'dataType': ['text'], + 'description': 'The provider of the source', + 'name': 'provider', + } ], # TODO: optimisation for large number of objects 'vectorIndexType': 'hnsw', @@ -126,31 +131,17 @@ def get_objects_from_metadata( user_id: str, metadata_key: str, values: List[str], - contains: bool = False, ) -> dict: # NOTE: the limit of objects returned is not known, maybe it would be better to set one manually if not self.client: raise Exception('Error: Weaviate client not initialised') - if not self.client.schema.exists(COLLECTION_NAME(user_id)): - self.setup_schema(user_id) + self.setup_schema(user_id) if len(values) == 0: return {} - # todo - if len(values) == 1: - if contains: - data_filter = { metadata_key: { '$in': values[0] } } - else: - data_filter = { metadata_key: values[0] } - else: - if contains: - data_filter = {'$or': [{ metadata_key: { '$in': val } } for val in values]} - else: - data_filter = {'$or': [{ metadata_key: val } for val in values]} - data_filter = { 'path': [metadata_key], 'operator': 'ContainsAny',