-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ported init and persistent storage changes
- use AppAPI's persistent volume
- add /init route and enabled guard
- hash check downloaded model files

Signed-off-by: Anupam Kumar <[email protected]>
- Loading branch information
Showing
17 changed files
with
356 additions
and
215 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,5 @@ | ||
.venv/ | ||
__pycache__/ | ||
.env | ||
model_files/* | ||
vector_db_data/* | ||
persistent_storage/* | ||
.vscode/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,70 +1,40 @@ | ||
from os import getenv | ||
import os | ||
|
||
from dotenv import load_dotenv | ||
import uvicorn | ||
|
||
from .config_parser import get_config | ||
from .controller import app | ||
from .download import download_all_models | ||
from .models import models | ||
from .download import model_init | ||
from .utils import to_int | ||
from .vectordb import vector_dbs | ||
|
||
load_dotenv() | ||
|
||
__all__ = ['create_server', 'vector_dbs', 'models'] | ||
__all__ = ['app', 'to_int'] | ||
|
||
|
||
def create_server(config: dict[str, tuple[str, dict]]): | ||
def _setup_env_vars(): | ||
''' | ||
Creates a FastAPI server with the given config. | ||
Args | ||
---- | ||
config: dict | ||
A dictionary containing the services to be deployed. | ||
Sets up the environment variables for persistent storage. | ||
''' | ||
if getenv('DISABLE_CUSTOM_DOWNLOAD_URI', '0') != '1': | ||
if (model_name := download_all_models(config)) is not None: | ||
raise Exception(f'Error: Model download failed for {model_name}') | ||
|
||
app.extra['CONFIG'] = config | ||
|
||
if config.get('embedding'): | ||
from .models import init_model | ||
|
||
model = init_model('embedding', config.get('embedding')) | ||
app.extra['EMBEDDING_MODEL'] = model | ||
|
||
if config.get('vectordb'): | ||
from .vectordb import get_vector_db | ||
persistent_storage = os.getenv('APP_PERSISTENT_STORAGE', 'persistent_storage') | ||
|
||
client_klass = get_vector_db(config.get('vectordb')[0]) | ||
vector_db_dir = os.path.join(persistent_storage, 'vector_db_data') | ||
if not os.path.exists(vector_db_dir): | ||
os.makedirs(vector_db_dir, 0o750, True) | ||
|
||
if app.extra.get('EMBEDDING_MODEL') is not None: | ||
app.extra['VECTOR_DB'] = client_klass(app.extra['EMBEDDING_MODEL'], **config.get('vectordb')[1]) | ||
else: | ||
app.extra['VECTOR_DB'] = client_klass(**config.get('vectordb')[1]) | ||
model_dir = os.path.join(persistent_storage, 'model_files') | ||
if not os.path.exists(model_dir): | ||
os.makedirs(model_dir, 0o750, True) | ||
|
||
if config.get('llm'): | ||
from .models import init_model | ||
os.environ['APP_PERSISTENT_STORAGE'] = persistent_storage | ||
os.environ['VECTORDB_DIR'] = vector_db_dir | ||
os.environ['MODEL_DIR'] = model_dir | ||
os.environ['SENTENCE_TRANSFORMERS_HOME'] = os.getenv('SENTENCE_TRANSFORMERS_HOME', model_dir) | ||
os.environ['TRANSFORMERS_CACHE'] = os.getenv('TRANSFORMERS_CACHE', model_dir) | ||
|
||
llm_name, llm_config = config.get('llm') | ||
app.extra['LLM_TEMPLATE'] = llm_config.pop('template', '') | ||
app.extra['LLM_END_SEPARATOR'] = llm_config.pop('end_separator', '') | ||
|
||
model = init_model('llm', (llm_name, llm_config)) | ||
app.extra['LLM_MODEL'] = model | ||
_setup_env_vars() | ||
|
||
uvicorn.run( | ||
app=app, | ||
host=getenv('APP_HOST', '0.0.0.0'), | ||
port=to_int(getenv('APP_PORT'), 9000), | ||
http='h11', | ||
interface='asgi3', | ||
log_level=('warning', 'trace')[getenv('DEBUG', '0') == '1'], | ||
use_colors=True, | ||
limit_concurrency=100, | ||
backlog=100, | ||
timeout_keep_alive=10, | ||
h11_max_incomplete_event_size=5 * 1024 * 1024, # 5MB | ||
) | ||
app.extra['CONFIG'] = get_config() | ||
app.extra['ENABLED'] = model_init(app) | ||
print('App', 'enabled' if app.extra['ENABLED'] else 'disabled', 'at startup', flush=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from pprint import pprint | ||
|
||
from ruamel.yaml import YAML | ||
|
||
from .models import models | ||
from .vectordb import vector_dbs | ||
|
||
|
||
def _first_in_list( | ||
input_dict: dict[str, dict], | ||
supported_list: list[str] | ||
) -> tuple[str, dict] | None: | ||
''' | ||
Find the first matching item in the input list from the supported list. | ||
This is done to find the first supported item in the config file. | ||
''' | ||
for input_item, value in input_dict.items(): | ||
if input_item in supported_list: | ||
return (input_item, value or {}) | ||
|
||
return None | ||
|
||
|
||
def get_config(file_path: str = 'config.yaml') -> dict[str, tuple[str, dict]]:
    '''
    Get the config from the given file path (relative to the root directory).

    For each of the 'vectordb', 'embedding' and 'llm' sections, the first
    entry whose name is supported is selected.

    Args
    ----
    file_path: str
        Path of the YAML config file.

    Returns
    -------
    dict[str, tuple[str, dict]]
        Section name mapped to the selected (name, options) pair.

    Raises
    ------
    AssertionError
        If the file cannot be parsed, does not contain a mapping at the top
        level, or one of the sections has no supported entry.
    OSError
        If the file cannot be opened (raised by `open`, outside the parse
        guard).
    '''
    with open(file_path) as f:
        try:
            yaml = YAML(typ='safe')
            config = yaml.load(f)
        except Exception as e:
            # Format the message into one string; passing the parts as
            # separate arguments renders the error as a tuple.
            raise AssertionError(f'Error: could not load config from {file_path} file') from e

    # An empty document loads as None and a top-level list/scalar has no
    # `.get`; report both as a config error rather than an AttributeError.
    if not isinstance(config, dict):
        raise AssertionError(f'Error: config in {file_path} file should be a mapping of sections')

    # `or {}` covers sections that are present but empty (loaded as None).
    selected_config = {
        'vectordb': _first_in_list(config.get('vectordb') or {}, vector_dbs),
        'embedding': _first_in_list(config.get('embedding') or {}, models['embedding']),
        'llm': _first_in_list(config.get('llm') or {}, models['llm']),
    }

    if not selected_config['vectordb']:
        raise AssertionError(
            f'Error: vectordb should be at least one of {vector_dbs} in the config file'
        )

    if not selected_config['embedding']:
        raise AssertionError(
            f'Error: embedding model should be at least one of {models["embedding"]} in the config file'
        )

    if not selected_config['llm']:
        raise AssertionError(
            f'Error: llm model should be at least one of {models["llm"]} in the config file'
        )

    pprint(f'Selected config: {selected_config}')

    return selected_config
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.