Skip to content

Commit

Permalink
fix: unavailable model on startup
Browse files Browse the repository at this point in the history
  • Loading branch information
leoguillaumegouv committed Oct 9, 2024
1 parent 1f970a1 commit d1e6368
Showing 1 changed file with 25 additions and 16 deletions.
41 changes: 25 additions & 16 deletions app/helpers/_modelclients.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from app.schemas.embeddings import Embeddings
from app.schemas.models import Model, Models
from app.utils.config import LOGGER
from app.utils.exceptions import ContextLengthExceededException, ModelNotFoundException, ModelNotAvailableException
from app.utils.exceptions import ContextLengthExceededException, ModelNotAvailableException, ModelNotFoundException
from app.utils.variables import EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE


Expand Down Expand Up @@ -66,12 +66,20 @@ def get_models_list(self, *args, **kwargs):
def check_context_length(self, messages: List[Dict[str, str]], add_special_tokens: bool = True):
headers = {"Authorization": f"Bearer {self.api_key}"}
prompt = "\n".join([message["role"] + ": " + message["content"] for message in messages])
data = {"model": self.id, "prompt": prompt, "add_special_tokens": add_special_tokens}

if self.type == LANGUAGE_MODEL_TYPE:
data = {"model": self.id, "prompt": prompt, "add_special_tokens": add_special_tokens}
elif self.type == EMBEDDINGS_MODEL_TYPE:
data = {"inputs": prompt, "add_special_tokens": add_special_tokens}

response = requests.post(str(self.base_url).replace("/v1/", "/tokenize"), json=data, headers=headers)
response.raise_for_status()
response = response.json()

return response.json()["count"] <= response.json()["max_model_len"]
if self.type == LANGUAGE_MODEL_TYPE:
return response["count"] <= response["max_model_len"]
elif self.type == EMBEDDINGS_MODEL_TYPE:
return len(response[0]) <= self.max_model_len


def create_embeddings(self, *args, **kwargs):
Expand All @@ -91,17 +99,20 @@ def create_embeddings(self, *args, **kwargs):
class ModelClient(OpenAI):
DEFAULT_TIMEOUT = 10

def __init__(self, type=Literal[EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE], search_internet: bool = False, *args, **kwargs):
def __init__(self, type=Literal[EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE], *args, **kwargs):
"""
ModelClient class extends OpenAI class to support custom methods.
"""
super().__init__(timeout=self.DEFAULT_TIMEOUT, *args, **kwargs)
self.type = type
self.id = None
self.owned_by = None
self.created = None

# set attributes for unavailable models
self.id = ""
self.owned_by = ""
self.created = round(time.time())
self.max_model_len = None
self.search_internet = search_internet

# set real attributes if model is available
self.models.list = partial(get_models_list, self)
response = self.models.list()

Expand All @@ -110,9 +121,7 @@ def __init__(self, type=Literal[EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE], sea
self.vector_size = len(response.data[0].embedding)
self.embeddings.create = partial(create_embeddings, self)

# @ TODO : extends to embeddings models
if self.type == LANGUAGE_MODEL_TYPE:
self.check_context_length = partial(check_context_length, self)
self.check_context_length = partial(check_context_length, self)


class ModelClients(dict):
Expand All @@ -121,16 +130,16 @@ class ModelClients(dict):
"""

def __init__(self, config: Config):
for model in config.models:
model = ModelClient(base_url=model.url, api_key=model.key, type=model.type, search_internet=model.search_internet)
for model_config in config.models:
model = ModelClient(base_url=model_config.url, api_key=model_config.key, type=model_config.type)
if model.status == "unavailable":
LOGGER.info(f"error to request the model API on {model.url}, skipping.")
LOGGER.error(f"unavailable model API on {model_config.url}, skipping.")
continue
self.__setitem__(model.id, model)

if model.search_internet and model.type == EMBEDDINGS_MODEL_TYPE:
if model_config.search_internet and model_config.type == EMBEDDINGS_MODEL_TYPE:
self.SEARCH_INTERNET_EMBEDDINGS_MODEL_ID = model.id
if model.search_internet and model.type == LANGUAGE_MODEL_TYPE:
if model_config.search_internet and model_config.type == LANGUAGE_MODEL_TYPE:
self.SEARCH_INTERNET_LANGUAGE_MODEL_ID = model.id

if "SEARCH_INTERNET_EMBEDDINGS_MODEL_ID" not in self.__dict__:
Expand Down

0 comments on commit d1e6368

Please sign in to comment.