Skip to content

Commit

Permalink
Introduce /deleteSourcesByProviderForAllUsers and fixes
Browse files Browse the repository at this point in the history
- add get_users() in base, chroma and weaviate
- /vectors outputs vectors for all the users
- fix setup when DISABLE_CUSTOM_DOWNLOAD_URI is set

Signed-off-by: Anupam Kumar <[email protected]>
  • Loading branch information
kyteinsky committed Mar 1, 2024
1 parent bd3a952 commit 5c4dfd1
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 13 deletions.
36 changes: 28 additions & 8 deletions context_chat_backend/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _(query: str | None = None):
# TODO: for testing, remove later
@app.get('/vectors')
@enabled_guard(app)
def _(userId: str):
def _():
from chromadb.api import ClientAPI

from .vectordb import COLLECTION_NAME
Expand All @@ -61,11 +61,12 @@ def _(userId: str):
if client is None:
return JSONResponse('Error: VectorDB client not initialised', 500)

db.setup_schema(userId)
vectors = {}
for user_id in db.get_users():
db.setup_schema(user_id)
vectors[user_id] = client.get_collection(COLLECTION_NAME(user_id)).get()

return JSONResponse(
client.get_collection(COLLECTION_NAME(userId)).get()
)
return JSONResponse(vectors)


# TODO: for testing, remove later
Expand Down Expand Up @@ -153,6 +154,25 @@ def _(userId: Annotated[str, Body()], providerKey: Annotated[str, Body()]):
return JSONResponse('All valid sources deleted')


@app.post('/deleteSourcesByProviderForAllUsers')
@enabled_guard(app)
def _(providerKey: str = Body(embed=True)):
if value_of(providerKey) is None:
return JSONResponse('Invalid provider key provided', 400)

db: BaseVectorDB | None = app.extra.get('VECTOR_DB')

if db is None:
return JSONResponse('Error: VectorDB not initialised', 500)

res = db.delete_for_all_users('provider', [providerKey])

if res is False:
return JSONResponse('Error: VectorDB delete failed, check vectordb logs for more info.', 400)

return JSONResponse('All valid sources deleted')


@app.put('/loadSources')
@enabled_guard(app)
def _(sources: list[UploadFile]):
Expand Down Expand Up @@ -237,11 +257,11 @@ def check_empty_values(cls, value: Any, info: FieldValidationInfo):

@field_validator('ctxLimit')
@classmethod
def at_least_one_context(cls, v: int):
if v < 1:
def at_least_one_context(cls, value: int):
if value < 1:
raise ValueError('Invalid context chunk limit')

return v
return value

@app.post('/scopedQuery')
@enabled_guard(app)
Expand Down
9 changes: 6 additions & 3 deletions context_chat_backend/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,6 @@ def download_all_models(app: FastAPI):

if os.getenv('DISABLE_CUSTOM_DOWNLOAD_URI', '0') == '1':
update_progress(100)
_set_app_config(app, config)
return

progress = 0
for model_type in ('embedding', 'llm'):
Expand All @@ -251,6 +249,12 @@ def model_init(app: FastAPI) -> bool:
global _MODELS_DIR
_MODELS_DIR = os.getenv('MODEL_DIR', 'persistent_storage/model_files')

config: TConfig = app.extra['CONFIG']

if os.getenv('DISABLE_CUSTOM_DOWNLOAD_URI', '0') == '1':
_set_app_config(app, config)
return True

for model_type in ('embedding', 'llm'):
model_name = _get_model_name_or_path(app.extra['CONFIG'], model_type)
if model_name is None:
Expand All @@ -259,7 +263,6 @@ def model_init(app: FastAPI) -> bool:
if not _model_exists(model_name):
return False

config: TConfig = app.extra['CONFIG']
_set_app_config(app, config)

return True
1 change: 1 addition & 0 deletions context_chat_backend/vectordb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# class name/index name is capitalized (user1 => User1) maybe because it is a class name,
# so the solution is to use Vector_user1 instead of user1
COLLECTION_NAME = lambda user_id: f'Vector_{user_id}'
USER_ID_FROM_COLLECTION = lambda collection: collection.split('_')[-1]


def get_vector_db(db_name: str) -> BaseVectorDB:
Expand Down
39 changes: 39 additions & 0 deletions context_chat_backend/vectordb/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@ class BaseVectorDB(ABC):
def __init__(self, embedding: Embeddings | None = None, **kwargs):
self.embedding = embedding

@abstractmethod
def get_users(self) -> list[str]:
'''
Returns a list of all user IDs.
Returns
-------
list[str]
List of user IDs.
'''

@abstractmethod
def get_user_client(
self,
Expand Down Expand Up @@ -175,3 +186,31 @@ def delete(self, user_id: str, metadata_key: str, values: list[str]) -> bool:
]

return self.delete_by_ids(user_id, ids)

def delete_for_all_users(self, metadata_key: str, values: list[str]) -> bool:
'''
Deletes all documents with the matching values for the given metadata key for all users.
Args
----
metadata_key: str
Metadata key to delete by.
values: list[str]
List of metadata values to match.
Returns
-------
bool
True if deletion is successful,
False otherwise
'''
if len(values) == 0:
return True

success = True

users = self.get_users()
for user_id in users:
success &= self.delete(user_id, metadata_key, values)

return success
8 changes: 7 additions & 1 deletion context_chat_backend/vectordb/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from langchain.schema.embeddings import Embeddings
from langchain.vectorstores import Chroma, VectorStore

from . import COLLECTION_NAME
from . import COLLECTION_NAME, USER_ID_FROM_COLLECTION
from .base import BaseVectorDB, MetadataFilter, TSearchDict

load_dotenv()
Expand All @@ -33,6 +33,12 @@ def __init__(self, embedding: Embeddings | None = None, **kwargs):
self.client = client
self.embedding = embedding

def get_users(self) -> list[str]:
if not self.client:
raise Exception('Error: Chromadb client not initialised')

return [USER_ID_FROM_COLLECTION(collection.name) for collection in self.client.list_collections()]

def setup_schema(self, user_id: str) -> None:
if not self.client:
raise Exception('Error: Chromadb client not initialised')
Expand Down
11 changes: 10 additions & 1 deletion context_chat_backend/vectordb/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from weaviate import AuthApiKey, Client

from ..utils import value_of
from . import COLLECTION_NAME
from . import COLLECTION_NAME, USER_ID_FROM_COLLECTION
from .base import BaseVectorDB, MetadataFilter, TSearchDict

load_dotenv()
Expand Down Expand Up @@ -89,6 +89,15 @@ def __init__(self, embedding: Embeddings | None = None, **kwargs):
self.client = client
self.embedding = embedding

def get_users(self) -> list[str]:
if not self.client:
raise Exception('Error: Weaviate client not initialised')

return [
USER_ID_FROM_COLLECTION(klass.get('class', ''))
for klass in self.client.schema.get().get('classes', [])
]

def setup_schema(self, user_id: str) -> None:
if not self.client:
raise Exception('Error: Weaviate client not initialised')
Expand Down

0 comments on commit 5c4dfd1

Please sign in to comment.