ported init and persistent storage changes
- use AppAPI's persistent volume
- add /init route and enabled guard
- hash check downloaded model files

Signed-off-by: Anupam Kumar <[email protected]>
kyteinsky committed Feb 22, 2024
1 parent 6d0ddcb commit 4c49f06
Showing 17 changed files with 356 additions and 215 deletions.
3 changes: 1 addition & 2 deletions .gitignore
@@ -1,6 +1,5 @@
 .venv/
 __pycache__/
 .env
-model_files/*
-vector_db_data/*
+persistent_storage/*
 .vscode/
6 changes: 0 additions & 6 deletions Dockerfile
@@ -1,11 +1,5 @@
 FROM python:3.11-bookworm
 
-VOLUME /app/model_files
-VOLUME /app/vector_db_data
-ENV VECTORDB_DIR /app/vector_db_data
-ENV SENTENCE_TRANSFORMERS_HOME /app/model_files
-ENV TRANSFORMERS_CACHE /app/model_files
-
 RUN apt update && apt install -y --no-install-recommends pandoc
 
 WORKDIR /app
7 changes: 3 additions & 4 deletions config.yaml
@@ -1,7 +1,6 @@
 vectordb:
   chroma:
     is_persistent: True
-    persist_directory: ./vector_db_data
     # chroma_server_host:
     # chroma_server_http_port
     # chroma_server_ssl_enabled
@@ -16,7 +15,7 @@ embedding:
     device: cpu
 
   llama:
-    model_path: model_files/dolphin-2.2.1-mistral-7b.Q5_K_M.gguf
+    model_path: dolphin-2.2.1-mistral-7b.Q5_K_M.gguf
     n_batch: 16
     n_ctx: 2048
 
@@ -28,14 +27,14 @@ embedding:
 
 llm:
   llama:
-    model_path: model_files/dolphin-2.2.1-mistral-7b.Q5_K_M.gguf
+    model_path: dolphin-2.2.1-mistral-7b.Q5_K_M.gguf
     n_batch: 10
     n_ctx: 4096
     template: "<|im_start|> system \nYou're an AI assistant good at finding relevant context from documents to answer questions provided by the user. <|im_end|>\n<|im_start|> user\nUse the following documents as context to answer the question at the end. REMEMBER to exercise source criticism as the documents are returned by a search provider that can return unrelated documents.\n\nSTART OF CONTEXT: \n{context} \n\nEND OF CONTEXT!\n\nIf you don't know the answer or are unsure, just say that you don't know, don't try to make up an answer. Don't mention the context in your answer but rather just answer the question directly. \nQuestion: {question} Let's think this step-by-step. \n<|im_end|>\n<|im_start|> assistant\n"
     end_separator: <|im_end|>
 
   ctransformer:
-    model: model_files/dolphin-2.2.1-mistral-7b.Q5_K_M.gguf
+    model: dolphin-2.2.1-mistral-7b.Q5_K_M.gguf
     template: "<|im_start|> system \nYou're an AI assistant good at finding relevant context from documents to answer questions provided by the user. <|im_end|>\n<|im_start|> user\nUse the following documents as context to answer the question at the end. REMEMBER to exercise source criticism as the documents are returned by a search provider that can return unrelated documents.\n\nSTART OF CONTEXT: \n{context} \n\nEND OF CONTEXT!\n\nIf you don't know the answer or are unsure, just say that you don't know, don't try to make up an answer. Don't mention the context in your answer but rather just answer the question directly. \nQuestion: {question} Let's think this step-by-step. \n<|im_end|>\n<|im_start|> assistant\n"
     end_separator: <|im_end|>
 
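With the Docker VOLUMEs gone, the model_path/model entries drop their model_files/ prefix; presumably the loaders now resolve bare filenames against the MODEL_DIR that _setup_env_vars exports (see context_chat_backend/__init__.py below). A hypothetical resolver sketch — the helper name and fallback path are assumptions, not code from this commit:

import os


def resolve_model_path(model_path: str) -> str:
	# Hypothetical helper: absolute or existing relative paths pass through,
	# bare filenames are looked up in the persistent model directory.
	if os.path.isabs(model_path) or os.path.exists(model_path):
		return model_path
	return os.path.join(os.getenv('MODEL_DIR', 'persistent_storage/model_files'), model_path)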
74 changes: 22 additions & 52 deletions context_chat_backend/__init__.py
@@ -1,70 +1,40 @@
-from os import getenv
+import os
 
 from dotenv import load_dotenv
-import uvicorn
 
+from .config_parser import get_config
 from .controller import app
-from .download import download_all_models
-from .models import models
+from .download import model_init
 from .utils import to_int
-from .vectordb import vector_dbs
 
 load_dotenv()
 
-__all__ = ['create_server', 'vector_dbs', 'models']
+__all__ = ['app', 'to_int']
 
 
-def create_server(config: dict[str, tuple[str, dict]]):
+def _setup_env_vars():
 	'''
-	Creates a FastAPI server with the given config.
-
-	Args
-	----
-	config: dict
-		A dictionary containing the services to be deployed.
+	Sets up the environment variables for persistent storage.
 	'''
-	if getenv('DISABLE_CUSTOM_DOWNLOAD_URI', '0') != '1':
-		if (model_name := download_all_models(config)) is not None:
-			raise Exception(f'Error: Model download failed for {model_name}')
-
-	app.extra['CONFIG'] = config
-
-	if config.get('embedding'):
-		from .models import init_model
-
-		model = init_model('embedding', config.get('embedding'))
-		app.extra['EMBEDDING_MODEL'] = model
-
-	if config.get('vectordb'):
-		from .vectordb import get_vector_db
+	persistent_storage = os.getenv('APP_PERSISTENT_STORAGE', 'persistent_storage')
 
-		client_klass = get_vector_db(config.get('vectordb')[0])
+	vector_db_dir = os.path.join(persistent_storage, 'vector_db_data')
+	if not os.path.exists(vector_db_dir):
+		os.makedirs(vector_db_dir, 0o750, True)
 
-		if app.extra.get('EMBEDDING_MODEL') is not None:
-			app.extra['VECTOR_DB'] = client_klass(app.extra['EMBEDDING_MODEL'], **config.get('vectordb')[1])
-		else:
-			app.extra['VECTOR_DB'] = client_klass(**config.get('vectordb')[1])
+	model_dir = os.path.join(persistent_storage, 'model_files')
+	if not os.path.exists(model_dir):
+		os.makedirs(model_dir, 0o750, True)
 
-	if config.get('llm'):
-		from .models import init_model
+	os.environ['APP_PERSISTENT_STORAGE'] = persistent_storage
+	os.environ['VECTORDB_DIR'] = vector_db_dir
+	os.environ['MODEL_DIR'] = model_dir
+	os.environ['SENTENCE_TRANSFORMERS_HOME'] = os.getenv('SENTENCE_TRANSFORMERS_HOME', model_dir)
+	os.environ['TRANSFORMERS_CACHE'] = os.getenv('TRANSFORMERS_CACHE', model_dir)
 
-		llm_name, llm_config = config.get('llm')
-		app.extra['LLM_TEMPLATE'] = llm_config.pop('template', '')
-		app.extra['LLM_END_SEPARATOR'] = llm_config.pop('end_separator', '')
-
-		model = init_model('llm', (llm_name, llm_config))
-		app.extra['LLM_MODEL'] = model
+_setup_env_vars()
 
-	uvicorn.run(
-		app=app,
-		host=getenv('APP_HOST', '0.0.0.0'),
-		port=to_int(getenv('APP_PORT'), 9000),
-		http='h11',
-		interface='asgi3',
-		log_level=('warning', 'trace')[getenv('DEBUG', '0') == '1'],
-		use_colors=True,
-		limit_concurrency=100,
-		backlog=100,
-		timeout_keep_alive=10,
-		h11_max_incomplete_event_size=5 * 1024 * 1024,  # 5MB
-	)
+app.extra['CONFIG'] = get_config()
+app.extra['ENABLED'] = model_init(app)
+print('App', 'enabled' if app.extra['ENABLED'] else 'disabled', 'at startup', flush=True)
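A note on the directory setup: in os.makedirs(dir, 0o750, True) the positional arguments map to mode and exist_ok, so the preceding os.path.exists checks are redundant but harmless. A minimal standalone sketch of the same layout, assuming AppAPI injects APP_PERSISTENT_STORAGE into the container environment:

import os

# Falls back to a local directory when AppAPI's mount variable is absent.
persistent_storage = os.getenv('APP_PERSISTENT_STORAGE', 'persistent_storage')

for subdir in ('vector_db_data', 'model_files'):
	# mode=0o750, exist_ok=True -- equivalent to the positional call above
	os.makedirs(os.path.join(persistent_storage, subdir), 0o750, True)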
58 changes: 58 additions & 0 deletions context_chat_backend/config_parser.py
@@ -0,0 +1,58 @@
from pprint import pprint

from ruamel.yaml import YAML

from .models import models
from .vectordb import vector_dbs


def _first_in_list(
	input_dict: dict[str, dict],
	supported_list: list[str]
) -> tuple[str, dict] | None:
	'''
	Find the first matching item in the input list from the supported list.
	This is done to find the first supported item in the config file.
	'''
	for input_item, value in input_dict.items():
		if input_item in supported_list:
			return (input_item, value or {})

	return None


def get_config(file_path: str = 'config.yaml') -> dict[str, tuple[str, dict]]:
	'''
	Get the config from the given file path (relative to the root directory).
	'''
	with open(file_path) as f:
		try:
			yaml = YAML(typ='safe')
			config: dict = yaml.load(f)
		except Exception as e:
			raise AssertionError('Error: could not load config from', file_path, 'file') from e

	selected_config = {
		'vectordb': _first_in_list(config.get('vectordb', {}), vector_dbs),
		'embedding': _first_in_list(config.get('embedding', {}), models['embedding']),
		'llm': _first_in_list(config.get('llm', {}), models['llm']),
	}

	if not selected_config['vectordb']:
		raise AssertionError(
			f'Error: vectordb should be at least one of {vector_dbs} in the config file'
		)

	if not selected_config['embedding']:
		raise AssertionError(
			f'Error: embedding model should be at least one of {models["embedding"]} in the config file'
		)

	if not selected_config['llm']:
		raise AssertionError(
			f'Error: llm model should be at least one of {models["llm"]} in the config file'
		)

	pprint(f'Selected config: {selected_config}')

	return selected_config
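To illustrate the selection logic, a toy run of _first_in_list; the backend names used here are assumptions for the example, not the actual contents of vector_dbs:

def _first_in_list(input_dict, supported_list):
	# Same logic as above: the first config key naming a supported backend wins.
	for input_item, value in input_dict.items():
		if input_item in supported_list:
			return (input_item, value or {})
	return None


print(_first_in_list({'chroma': {'is_persistent': True}}, ['weaviate', 'chroma']))
# ('chroma', {'is_persistent': True})
print(_first_in_list({'faiss': None}, ['weaviate', 'chroma']))
# None -- get_config turns this into an AssertionError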
53 changes: 24 additions & 29 deletions context_chat_backend/controller.py
@@ -1,14 +1,14 @@
 from os import getenv
-from typing import Annotated, Any
+from typing import Annotated
 
 from dotenv import load_dotenv
-from fastapi import Body, FastAPI, Request, UploadFile
-from fastapi.responses import JSONResponse as FastAPIJSONResponse
+from fastapi import Body, FastAPI, Request, UploadFile, BackgroundTasks
 from langchain.llms.base import LLM
 
 from .chain import embed_sources, process_query
+from .download import download_all_models
 from .ocs_utils import AppAPIAuthMiddleware
-from .utils import value_of
+from .utils import enabled_guard, JSONResponse, update_progress, value_of
 from .vectordb import BaseVectorDB
 
 load_dotenv()
@@ -22,30 +22,6 @@
 app.add_middleware(AppAPIAuthMiddleware)
 
 
-def JSONResponse(
-	content: Any = 'ok',
-	status_code: int = 200,
-	**kwargs
-) -> FastAPIJSONResponse:
-	'''
-	Wrapper for FastAPI JSONResponse
-	'''
-	if isinstance(content, str):
-		if status_code >= 400:
-			return FastAPIJSONResponse(
-				content={ 'error': content },
-				status_code=status_code,
-				**kwargs,
-			)
-		return FastAPIJSONResponse(
-			content={ 'message': content },
-			status_code=status_code,
-			**kwargs,
-		)
-
-	return FastAPIJSONResponse(content, status_code, **kwargs)
-
-
 @app.get('/')
 def _(request: Request):
 	'''
@@ -56,13 +32,15 @@ def _(request: Request):
 
 # TODO: for testing, remove later
 @app.get('/world')
+@enabled_guard(app)
 def _(query: str | None = None):
 	em = app.extra.get('EMBEDDING_MODEL')
 	return em.embed_query(query if query is not None else 'what is an apple?')
 
 
 # TODO: for testing, remove later
 @app.get('/vectors')
+@enabled_guard(app)
 def _(userId: str):
 	from chromadb import ClientAPI
 	from .utils import COLLECTION_NAME
@@ -78,6 +56,7 @@ def _(userId: str):
 
 # TODO: for testing, remove later
 @app.get('/search')
+@enabled_guard(app)
 def _(userId: str, keyword: str):
 	from chromadb import ClientAPI
 	from .utils import COLLECTION_NAME
@@ -96,7 +75,8 @@ def _(userId: str, keyword: str):
 
 @app.put('/enabled')
 def _(enabled: bool):
-	print(f'{enabled:}')
+	app.extra['ENABLED'] = enabled
+	print('App', 'enabled' if enabled else 'disabled', flush=True)
 	return JSONResponse(content={'error': ''}, status_code=200)
 
 
@@ -106,7 +86,19 @@ def _():
 	return JSONResponse(content={'status': 'ok'}, status_code=200)
 
 
+@app.post('/init')
+def _(bg_tasks: BackgroundTasks):
+	if not app.extra.get('ENABLED', False):
+		bg_tasks.add_task(download_all_models, app)
+		return JSONResponse(content={}, status_code=200)
+
+	update_progress(100)
+	print('App already initialised', flush=True)
+	return JSONResponse(content={}, status_code=200)
+
+
 @app.post('/deleteSources')
+@enabled_guard(app)
 def _(userId: Annotated[str, Body()], sourceNames: Annotated[list[str], Body()]):
 	sourceNames = [source.strip() for source in sourceNames if source.strip() != '']
 
@@ -139,6 +131,7 @@ def _(userId: Annotated[str, Body()], sourceNames: Annotated[list[str], Body()])
 
 
 @app.post('/deleteMatchingSources')
+@enabled_guard(app)
 def _(userId: Annotated[str, Body()], keyword: Annotated[str, Body()]):
 	db: BaseVectorDB = app.extra.get('VECTOR_DB')
 
@@ -166,6 +159,7 @@ def _(userId: Annotated[str, Body()], keyword: Annotated[str, Body()]):
 
 
 @app.put('/loadSources')
+@enabled_guard(app)
 def _(sources: list[UploadFile]):
 	if len(sources) == 0:
 		return JSONResponse('No sources provided', 400)
@@ -191,6 +185,7 @@ def _(sources: list[UploadFile]):
 
 
 @app.get('/query')
+@enabled_guard(app)
 def _(userId: str, query: str, useContext: bool = True, ctxLimit: int = 5):
 	llm: LLM = app.extra.get('LLM_MODEL')
 	if llm is None:
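The utils.py changes are among the files not shown above, but from its use sites, enabled_guard(app) is evidently a decorator factory that short-circuits a route while app.extra['ENABLED'] is false. A hypothetical reconstruction, not the actual implementation:

from functools import wraps

from fastapi import FastAPI
from fastapi.responses import JSONResponse


def enabled_guard(app: FastAPI):
	def decorator(func):
		@wraps(func)
		def wrapper(*args, **kwargs):
			# Reject requests until /init (or PUT /enabled) has enabled the app.
			if not app.extra.get('ENABLED', False):
				return JSONResponse({'error': 'App is disabled'}, 503)
			return func(*args, **kwargs)
		return wrapper
	return decorator

Together with POST /init, which schedules download_all_models as a background task, this gives AppAPI a clean startup handshake instead of blocking the server on model downloads.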
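The remaining changed files, including the download module where the commit message's "hash check downloaded model files" would land, are not shown here. A minimal sketch of such a check, assuming a known SHA-256 digest per model file; the function name and signature are illustrative only:

import hashlib


def file_digest_matches(path: str, expected_sha256: str, chunk_size: int = 1 << 20) -> bool:
	# Stream the file in chunks so multi-GB GGUF models are not loaded into memory.
	h = hashlib.sha256()
	with open(path, 'rb') as f:
		while chunk := f.read(chunk_size):
			h.update(chunk)
	return h.hexdigest() == expected_sha256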
