Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: metadata search for provider #11

Merged
merged 1 commit into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion context_chat_backend/chain/ingest/injest.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def _sources_to_documents(sources: list[UploadFile]) -> list[Document]:
'title': source.headers.get('title'),
'type': source.headers.get('type'),
'modified': source.headers.get('modified'),
'provider': source.headers.get('provider'),
}

document = Document(page_content=content, metadata=metadata)
Expand Down Expand Up @@ -158,7 +159,7 @@ def embed_sources(
# either not a file or a file that is allowed
sources_filtered = [
source for source in sources
if not source.filename.startswith('file: ')
if (source.filename is not None and not source.filename.startswith('file: '))
or _allowed_file(source)
]

Expand Down
59 changes: 20 additions & 39 deletions context_chat_backend/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,21 @@ def _(userId: str):
# TODO: for testing, remove later
@app.get('/search')
@enabled_guard(app)
def _(userId: str, keyword: str):
from chromadb import ClientAPI
from .vectordb import COLLECTION_NAME
def _(userId: str, sourceNames: str):
sourceNames: list[str] = [source.strip() for source in sourceNames.split(',') if source.strip() != '']

if len(sourceNames) == 0:
return JSONResponse('No sources provided', 400)

db: BaseVectorDB = app.extra.get('VECTOR_DB')
client: ClientAPI = db.client
db.setup_schema(userId)

return JSONResponse(
client.get_collection(COLLECTION_NAME(userId)).get(
where_document={'$contains': [{'source': keyword}]},
include=['metadatas'],
)
)
if db is None:
return JSONResponse('Error: VectorDB not initialised', 500)

source_objs = db.get_objects_from_metadata(userId, 'source', sourceNames)
sources = list(map(lambda s: s.get('id'), source_objs.values()))

return JSONResponse({ 'sources': sources })


@app.put('/enabled')
Expand Down Expand Up @@ -110,47 +111,26 @@ def _(userId: Annotated[str, Body()], sourceNames: Annotated[list[str], Body()])
if db is None:
return JSONResponse('Error: VectorDB not initialised', 500)

source_objs = db.get_objects_from_metadata(userId, 'source', sourceNames)
res = db.delete_by_ids(userId, [
source.get('id')
for source in source_objs.values()
if value_of(source.get('id') is not None)
])

# NOTE: None returned in `delete_by_ids` should have meant an error but it didn't in the case of
# weaviate maybe because of the way weaviate wrapper is implemented (langchain's api does not take
# class name as input, which will be required in future versions of weaviate)
if res is None:
print('Deletion query returned "None". This can happen in Weaviate even if the deletion was \
successful, therefore not considered an error for now.')
res = db.delete(userId, 'source', sourceNames)

if res is False:
return JSONResponse('Error: VectorDB delete failed, check vectordb logs for more info.', 400)

return JSONResponse('All valid sources deleted')


@app.post('/deleteMatchingSources')
@app.post('/deleteSourcesByProvider')
@enabled_guard(app)
def _(userId: Annotated[str, Body()], keyword: Annotated[str, Body()]):
def _(userId: Annotated[str, Body()], providerKey: Annotated[str, Body()]):
if value_of(providerKey) is None:
return JSONResponse('Invalid provider key provided', 400)

db: BaseVectorDB = app.extra.get('VECTOR_DB')

if db is None:
return JSONResponse('Error: VectorDB not initialised', 500)

objs = db.get_objects_from_metadata(userId, 'source', [keyword], True)
res = db.delete_by_ids(userId, [
obj.get('id')
for obj in objs.values()
if value_of(obj.get('id') is not None)
])

# NOTE: None returned in `delete_by_ids` should have meant an error but it didn't in the case of
# weaviate maybe because of the way weaviate wrapper is implemented (langchain's api does not take
# class name as input, which will be required in future versions of weaviate)
if res is None:
print('Deletion query returned "None". This can happen in Weaviate even if the deletion was \
successful, therefore not considered an error for now.')
res = db.delete(userId, 'provider', [providerKey])

if res is False:
return JSONResponse('Error: VectorDB delete failed, check vectordb logs for more info.', 400)
Expand All @@ -169,6 +149,7 @@ def _(sources: list[UploadFile]):
value_of(source.headers.get('userId'))
and value_of(source.headers.get('type'))
and value_of(source.headers.get('modified'))
and value_of(source.headers.get('provider'))
for source in sources]
):
return JSONResponse('Invaild/missing headers', 400)
Expand Down
61 changes: 52 additions & 9 deletions context_chat_backend/vectordb/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from langchain.schema.embeddings import Embeddings
from langchain.vectorstores import VectorStore

from ..utils import value_of


class BaseVectorDB(ABC):
client = None
Expand Down Expand Up @@ -56,7 +58,6 @@ def get_objects_from_metadata(
user_id: str,
metadata_key: str,
values: List[str],
contains: bool = False,
) -> dict:
'''
Get all objects with the given metadata key and values.
Expand All @@ -70,9 +71,6 @@ def get_objects_from_metadata(
Metadata key to get.
values: List[str]
List of metadata names to get.
contains: bool
If True, gets all objects that contain any of the given values,
otherwise gets all objects that have the given values.

Returns
-------
Expand All @@ -89,7 +87,7 @@ def get_objects_from_metadata(
}
'''

def delete_by_ids(self, user_id: str, ids: list[str]) -> Optional[bool]:
def delete_by_ids(self, user_id: str, ids: list[str]) -> bool:
'''
Deletes all documents with the given ids for the given user.

Expand All @@ -102,9 +100,9 @@ def delete_by_ids(self, user_id: str, ids: list[str]) -> Optional[bool]:

Returns
-------
Optional[bool]
Optional[bool]: True if deletion is successful,
False otherwise, None if not implemented.
bool
True if deletion is successful,
False otherwise
'''
if len(ids) == 0:
return True
Expand All @@ -113,4 +111,49 @@ def delete_by_ids(self, user_id: str, ids: list[str]) -> Optional[bool]:
if user_client is None:
return False

return user_client.delete(ids)
res = user_client.delete(ids)

# NOTE: None should have meant an error but it didn't in the case of
# weaviate maybe because of the way weaviate wrapper is implemented (langchain's api does not take
# class name as input, which will be required in future versions of weaviate)
if res is None:
print('Deletion query returned "None". This can happen in Weaviate even if the deletion was \
successful, therefore not considered an error for now.')
return True

return res

def delete(self, user_id: str, metadata_key: str, values: list[str]) -> bool:
'''
Deletes all documents with the matching values for the given metadata key.

Args
----
user_id: str
User ID from whose database to delete the documents.
metadata_key: str
Metadata key to delete by.
values: list[str]
List of metadata values to match.

Returns
-------
bool
True if deletion is successful,
False otherwise
'''
if len(values) == 0:
return True

user_client = self.get_user_client(user_id)
if user_client is None:
return False

objs = self.get_objects_from_metadata(user_id, metadata_key, values)
ids = [
obj.get('id')
for obj in objs.values()
if value_of(obj.get('id') is not None)
]

return self.delete_by_ids(user_id, ids)
12 changes: 1 addition & 11 deletions context_chat_backend/vectordb/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def get_objects_from_metadata(
user_id: str,
metadata_key: str,
values: List[str],
contains: bool = False,
) -> dict:
# NOTE: the limit of objects returned is not known, maybe it would be better to set one manually

Expand All @@ -77,16 +76,7 @@ def get_objects_from_metadata(
if len(values) == 0:
return {}

if len(values) == 1:
if contains:
data_filter = { metadata_key: { '$in': values[0] } }
else:
data_filter = { metadata_key: values[0] }
else:
if contains:
data_filter = {'$or': [{ metadata_key: { '$in': val } } for val in values]}
else:
data_filter = {'$or': [{ metadata_key: val } for val in values]}
data_filter = { metadata_key: { '$in': values } }

try:
results = self.client.get_collection(COLLECTION_NAME(user_id)).get(
Expand Down
21 changes: 6 additions & 15 deletions context_chat_backend/vectordb/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@
'description': 'Last modified time of the file',
'name': 'modified',
},
{
'dataType': ['text'],
'description': 'The provider of the source',
'name': 'provider',
}
],
# TODO: optimisation for large number of objects
'vectorIndexType': 'hnsw',
Expand Down Expand Up @@ -126,31 +131,17 @@ def get_objects_from_metadata(
user_id: str,
metadata_key: str,
values: List[str],
contains: bool = False,
) -> dict:
# NOTE: the limit of objects returned is not known, maybe it would be better to set one manually

if not self.client:
raise Exception('Error: Weaviate client not initialised')

if not self.client.schema.exists(COLLECTION_NAME(user_id)):
self.setup_schema(user_id)
self.setup_schema(user_id)

if len(values) == 0:
return {}

# todo
if len(values) == 1:
if contains:
data_filter = { metadata_key: { '$in': values[0] } }
else:
data_filter = { metadata_key: values[0] }
else:
if contains:
data_filter = {'$or': [{ metadata_key: { '$in': val } } for val in values]}
else:
data_filter = {'$or': [{ metadata_key: val } for val in values]}

data_filter = {
'path': [metadata_key],
'operator': 'ContainsAny',
Expand Down
Loading