Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for scoped context in query #13

Merged
merged 1 commit into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion context_chat_backend/chain/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from .ingest import embed_sources
from .one_shot import process_query
from .one_shot import ScopeType, process_query, process_scoped_query

__all__ = [
'ScopeType',
'embed_sources',
'process_query',
'process_scoped_query',
]
60 changes: 53 additions & 7 deletions context_chat_backend/chain/one_shot.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from enum import Enum
from logging import error as log_error

from langchain.llms.base import LLM

from ..vectordb import BaseVectorDB
Expand All @@ -9,28 +12,71 @@
'''


class ScopeType(Enum):
    '''
    Kind of metadata a scoped query can be restricted by.

    The value of each member is the string used as the metadata key
    when the scope filter is built.
    '''

    # restrict the context by the 'provider' metadata key
    PROVIDER = 'provider'
    # restrict the context by the 'source' metadata key
    SOURCE = 'source'


def process_query(
    user_id: str,
    vectordb: BaseVectorDB,
    llm: LLM,
    query: str,
    use_context: bool = True,
    ctx_limit: int = 5,
    ctx_filter: dict | None = None,
    template: str | None = None,
    end_separator: str = '',
) -> tuple[str, list[str]]:
    '''
    Answers the query with the LLM, optionally enriched with context documents.

    Args
    ----
    user_id: str
        User whose documents are searched for context.
    vectordb: BaseVectorDB
        Vector DB used to obtain the user's client.
    llm: LLM
        Model that generates the answer.
    query: str
        The question to answer.
    use_context: bool
        When False the query is sent to the model as is.
    ctx_limit: int
        Maximum number of context documents to retrieve.
    ctx_filter: dict | None
        Metadata filter forwarded to the similarity search, skipped when None.
    template: str | None
        Prompt template with `context` and `question` placeholders,
        `_LLM_TEMPLATE` is used when it is None or empty.
    end_separator: str
        Suffix removed from the end of the model's output.

    Returns
    -------
    tuple[str, list[str]]
        The model's output and the de-duplicated 'source' metadata values of
        the context documents, in retrieval order.
    '''
    if not use_context:
        return llm.predict(query), []

    user_client = vectordb.get_user_client(user_id)
    if user_client is None:
        # no client for this user: answer without any context
        return llm.predict(query), []

    # the filter is only forwarded when one was given, otherwise the search is unfiltered
    # NOTE(review): the keyword is always `filter`; confirm every vector store
    # client in use accepts the filter under that name
    if ctx_filter is not None:
        context_docs = user_client.similarity_search(query, k=ctx_limit, filter=ctx_filter)
    else:
        context_docs = user_client.similarity_search(query, k=ctx_limit)

    context_text = '\n\n'.join(f'{d.metadata.get("title")}\n{d.page_content}' for d in context_docs)
    prompt = (template or _LLM_TEMPLATE).format(context=context_text, question=query)

    # removesuffix() drops the separator as a whole; rstrip() would treat it as a
    # set of characters and could also strip legitimate trailing text
    output = llm.predict(prompt).strip().removesuffix(end_separator).strip()

    # dict.fromkeys() de-duplicates while keeping the retrieval order stable
    unique_sources: list[str] = list(dict.fromkeys(
        source for d in context_docs if (source := d.metadata.get('source'))
    ))

    return (output, unique_sources)


def process_scoped_query(
    user_id: str,
    vectordb: BaseVectorDB,
    llm: LLM,
    query: str,
    scope_type: ScopeType,
    scope_list: list[str],
    ctx_limit: int = 5,
    template: str | None = None,
    end_separator: str = '',
) -> tuple[str, list[str]]:
    '''
    Answers the query using context restricted to the given scope.

    Args
    ----
    user_id: str
        User whose documents are searched for context.
    vectordb: BaseVectorDB
        Vector DB that builds the metadata filter and provides the user's client.
    llm: LLM
        Model that generates the answer.
    query: str
        The question to answer.
    scope_type: ScopeType
        Its value is used as the metadata key of the filter.
    scope_list: list[str]
        Accepted values for that metadata key.
    ctx_limit: int
        Maximum number of context documents to retrieve.
    template: str | None
        Prompt template passed through to `process_query`.
    end_separator: str
        End separator passed through to `process_query`.

    Returns
    -------
    tuple[str, list[str]]
        Whatever `process_query` returns for the (possibly unfiltered) query.
    '''
    ctx_filter = vectordb.get_metadata_filter([{
        'metadata_key': scope_type.value,
        'values': scope_list,
    }])

    if ctx_filter is None:
        # NOTE(review): without a filter the query runs over all of the user's
        # documents, i.e. wider than the scope that was asked for
        log_error(
            'Error: could not get filter for (\n'
            f'scope type: {scope_type}\n'
            f'scope list: {scope_list}\n'
            ')\n\nproceeding with an unscoped query'
        )

    return process_query(
        user_id=user_id,
        vectordb=vectordb,
        llm=llm,
        query=query,
        use_context=True,
        ctx_limit=ctx_limit,
        ctx_filter=ctx_filter,
        template=template,
        end_separator=end_separator,
    )
74 changes: 68 additions & 6 deletions context_chat_backend/controller.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from os import getenv
from typing import Annotated
from typing import Annotated, Any

from dotenv import load_dotenv
from fastapi import BackgroundTasks, Body, FastAPI, Request, UploadFile
from langchain.llms.base import LLM
from pydantic import BaseModel, FieldValidationInfo, field_validator

from .chain import embed_sources, process_query
from .chain import ScopeType, embed_sources, process_query, process_scoped_query
from .download import download_all_models
from .ocs_utils import AppAPIAuthMiddleware
from .utils import JSONResponse, enabled_guard, update_progress, value_of
Expand Down Expand Up @@ -190,6 +191,15 @@ def _(userId: str, query: str, useContext: bool = True, ctxLimit: int = 5):
if db is None:
return JSONResponse('Error: VectorDB not initialised', 500)

if value_of(userId) is None:
return JSONResponse('Empty User ID', 400)

if value_of(query) is None:
return JSONResponse('Empty query', 400)

if ctxLimit < 1:
return JSONResponse('Invalid context chunk limit', 400)

template = app.extra.get('LLM_TEMPLATE')
end_separator = app.extra.get('LLM_END_SEPARATOR', '')

Expand All @@ -200,14 +210,66 @@ def _(userId: str, query: str, useContext: bool = True, ctxLimit: int = 5):
query=query,
use_context=useContext,
ctx_limit=ctxLimit,
template=template,
end_separator=end_separator,
**({'template': template} if template else {}),
)

if output is None:
return JSONResponse('Error: check if the model specified supports the query type', 500)
return JSONResponse({
'output': output,
'sources': sources,
})


class ScopedQuery(BaseModel):
    '''
    Request body of a scoped query.

    All fields are checked for empty values and `ctxLimit` must be at least 1.
    '''

    userId: str
    query: str
    # metadata key to scope by and the accepted values for it
    scopeType: ScopeType
    scopeList: list[str]
    # maximum number of context chunks to retrieve
    ctxLimit: int = 5

    @field_validator('userId', 'query', 'scopeList', 'ctxLimit')
    @classmethod
    def check_empty_values(cls, value: Any, info: FieldValidationInfo):
        # value_of() returning None is taken to mean the value is empty
        if value_of(value) is None:
            # a single formatted message, two positional args would be
            # reported as a tuple
            raise ValueError(f'Empty value for field {info.field_name}')

        return value

    @field_validator('ctxLimit')
    @classmethod
    def at_least_one_context(cls, v: int):
        if v < 1:
            raise ValueError('Invalid context chunk limit')

        return v

@app.post('/scopedQuery')
@enabled_guard(app)
def _(scopedQuery: ScopedQuery):
    '''
    Runs a scoped query and responds with the model's output and the sources used.

    Responds with a 500 error when the LLM or the vector DB is not initialised.
    '''
    llm: LLM | None = app.extra.get('LLM_MODEL')
    if llm is None:
        return JSONResponse('Error: LLM not initialised', 500)

    db: BaseVectorDB | None = app.extra.get('VECTOR_DB')
    if db is None:
        return JSONResponse('Error: VectorDB not initialised', 500)

    # optional overrides stored on the app, None/'' when not configured
    template = app.extra.get('LLM_TEMPLATE')
    end_separator = app.extra.get('LLM_END_SEPARATOR', '')

    (output, sources) = process_scoped_query(
        user_id=scopedQuery.userId,
        vectordb=db,
        llm=llm,
        query=scopedQuery.query,
        ctx_limit=scopedQuery.ctxLimit,
        template=template,
        end_separator=end_separator,
        scope_type=scopedQuery.scopeType,
        scope_list=scopedQuery.scopeList,
    )

    # sources is already a list, so it is passed on as is (a single 'sources' key)
    return JSONResponse({
        'output': output,
        'sources': sources,
    })
4 changes: 2 additions & 2 deletions context_chat_backend/vectordb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from importlib import import_module

from .base import BaseVectorDB
from .base import BaseVectorDB, MetadataFilter

vector_dbs = ['weaviate', 'chroma']

__all__ = ['get_vector_db', 'vector_dbs', 'BaseVectorDB', 'COLLECTION_NAME']
__all__ = ['get_vector_db', 'vector_dbs', 'BaseVectorDB', 'COLLECTION_NAME', 'MetadataFilter']


# class name/index name is capitalized (user1 => User1) maybe because it is a class name,
Expand Down
21 changes: 21 additions & 0 deletions context_chat_backend/vectordb/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ class TSearchObject(TypedDict):
TSearchDict = dict[str, TSearchObject]


class MetadataFilter(TypedDict):
    '''
    A single metadata condition: a key and the values accepted for it.
    '''

    # name of the metadata field to filter on
    metadata_key: str
    # values of that field that should match
    values: list[str]


class BaseVectorDB(ABC):
client: Any = None
embedding: Any = None
Expand Down Expand Up @@ -59,6 +64,22 @@ def setup_schema(self, user_id: str) -> None:
None
'''

@abstractmethod
def get_metadata_filter(self, filters: list[MetadataFilter]) -> dict | None:
    '''
    Returns the metadata filter for the given filters.

    Args
    ----
    filters: list[MetadataFilter]
        List of metadata filters.

    Returns
    -------
    dict | None
        Metadata filter dictionary in the format of the vector DB,
        or None when no filter could be built (e.g. `filters` is empty).
    '''

@abstractmethod
def get_objects_from_metadata(
self,
Expand Down
29 changes: 25 additions & 4 deletions context_chat_backend/vectordb/chroma.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from logging import error as log_error
from os import getenv

from chromadb import Client, Where
from chromadb import Client
from chromadb.config import Settings
from dotenv import load_dotenv
from langchain.schema.embeddings import Embeddings
from langchain.vectorstores import Chroma, VectorStore

from . import COLLECTION_NAME
from .base import BaseVectorDB, TSearchDict
from .base import BaseVectorDB, MetadataFilter, TSearchDict

load_dotenv()

Expand Down Expand Up @@ -59,6 +59,19 @@ def get_user_client(
embedding_function=em,
)

def get_metadata_filter(self, filters: list[MetadataFilter]) -> dict | None:
    '''
    Builds a Chroma `where` filter matching any of the given metadata filters.

    Args
    ----
    filters: list[MetadataFilter]
        List of metadata filters, entries without values are ignored.

    Returns
    -------
    dict | None
        A single `$in` clause when one usable filter is left, an `$or` of
        `$in` clauses when there are several, None when there is none.
    '''
    # Chroma's where validation rejects an empty `$in` list, so entries
    # without values are left out instead of producing an invalid filter
    clauses = [
        {f['metadata_key']: {'$in': f['values']}}
        for f in filters
        if f['values']
    ]

    if not clauses:
        return None

    # a lone clause is returned bare, `$or` is only used for two or more
    if len(clauses) == 1:
        return clauses[0]

    return {'$or': clauses}

def get_objects_from_metadata(
self,
user_id: str,
Expand All @@ -72,10 +85,18 @@ def get_objects_from_metadata(

self.setup_schema(user_id)

if len(values) == 0:
try:
data_filter = self.get_metadata_filter([{
'metadata_key': metadata_key,
'values': values,
}])
except KeyError as e:
# todo: info instead of error
log_error(f'Error: Chromadb filter error: {e}')
return {}

data_filter: Where = { metadata_key: { '$in': values } } # type: ignore
if data_filter is None:
return {}

try:
results = self.client.get_collection(COLLECTION_NAME(user_id)).get(
Expand Down
38 changes: 31 additions & 7 deletions context_chat_backend/vectordb/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from ..utils import value_of
from . import COLLECTION_NAME
from .base import BaseVectorDB, TSearchDict
from .base import BaseVectorDB, MetadataFilter, TSearchDict

load_dotenv()

Expand Down Expand Up @@ -125,6 +125,26 @@ def get_user_client(

return weaviate_obj

def get_metadata_filter(self, filters: list[MetadataFilter]) -> dict | None:
    '''
    Builds a Weaviate `where` filter matching any of the given metadata filters.

    Args
    ----
    filters: list[MetadataFilter]
        List of metadata filters, entries without values are ignored.

    Returns
    -------
    dict | None
        A single `ContainsAny` operand when one usable filter is left, an `Or`
        of such operands when there are several, None when there is none.
    '''
    # entries without values are skipped, they have nothing to match against
    operands = [
        {
            # the path is given in list form, one element per path segment
            'path': [f['metadata_key']],
            'operator': 'ContainsAny',
            'valueTextList': f['values'],
        }
        for f in filters
        if f['values']
    ]

    if not operands:
        return None

    # a lone operand is returned bare, `Or` is only used for two or more
    if len(operands) == 1:
        return operands[0]

    return {
        'operator': 'Or',
        'operands': operands,
    }

def get_objects_from_metadata(
self,
user_id: str,
Expand All @@ -138,14 +158,18 @@ def get_objects_from_metadata(

self.setup_schema(user_id)

if len(values) == 0:
try:
data_filter = self.get_metadata_filter([{
'metadata_key': metadata_key,
'values': values,
}])
except KeyError as e:
# todo: info instead of error
log_error(f'Error: Chromadb filter error: {e}')
return {}

data_filter = {
'path': [metadata_key],
'operator': 'ContainsAny',
'valueTextList': values,
}
if data_filter is None:
return {}

results = self.client.query \
.get(COLLECTION_NAME(user_id), [metadata_key, 'modified']) \
Expand Down