From 34d2f09b8abf029155192280dc0301f12e4b62f7 Mon Sep 17 00:00:00 2001
From: leoguillaume
Date: Wed, 9 Oct 2024 12:08:03 +0200
Subject: [PATCH] fix: max context length

Rename the model attribute max_model_len to max_context_length, expose it as
an optional field on the Model schema, compare token counts against the
stored attribute, and only check max_tokens when the parameter is provided.
The chat tests are updated accordingly.
---
 app/helpers/_modelclients.py | 13 +++++++------
 app/schemas/chat.py          |  2 +-
 app/schemas/models.py        |  3 ++-
 app/tests/test_chat.py       | 16 ++++++++--------
 4 files changed, 18 insertions(+), 16 deletions(-)

diff --git a/app/helpers/_modelclients.py b/app/helpers/_modelclients.py
index d84cb6a..5526f53 100644
--- a/app/helpers/_modelclients.py
+++ b/app/helpers/_modelclients.py
@@ -34,7 +34,7 @@ def get_models_list(self, *args, **kwargs):
         self.id = response["id"]
         self.owned_by = response.get("owned_by", "")
         self.created = response.get("created", round(time.time()))
-        self.max_model_len = response.get("max_model_len", None)
+        self.max_context_length = response.get("max_model_len", None)
 
     elif self.type == EMBEDDINGS_MODEL_TYPE:
         endpoint = str(self.base_url).replace("/v1/", "/info")
@@ -43,7 +43,7 @@
         self.id = response["model_id"]
         self.owned_by = "huggingface-text-embeddings-inference"
         self.created = round(time.time())
-        self.max_model_len = response.get("max_input_length", None)
+        self.max_context_length = response.get("max_input_length", None)
 
     self.status = "available"
 
@@ -55,7 +55,7 @@
         object="model",
         owned_by=self.owned_by,
         created=self.created,
-        max_model_len=self.max_model_len,
+        max_context_length=self.max_context_length,
         type=self.type,
         status=self.status,
     )
@@ -64,6 +64,7 @@
 
 
 def check_context_length(self, messages: List[Dict[str, str]], add_special_tokens: bool = True):
+    # TODO: remove this method and use better context length handling
     headers = {"Authorization": f"Bearer {self.api_key}"}
 
     prompt = "\n".join([message["role"] + ": " + message["content"] for message in messages])
@@ -77,9 +78,9 @@ def check_context_length(self, messages: List[Dict[str, str]], add_special_token
     response = response.json()
 
     if self.type == LANGUAGE_MODEL_TYPE:
-        return response["count"] <= response["max_model_len"]
+        return response["count"] <= self.max_context_length
     elif self.type == EMBEDDINGS_MODEL_TYPE:
-        return len(response[0]) <= self.max_model_len
+        return len(response[0]) <= self.max_context_length
 
 
 def create_embeddings(self, *args, **kwargs):
@@ -110,7 +111,7 @@ def __init__(self, type=Literal[EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE], *ar
         self.id = ""
         self.owned_by = ""
         self.created = round(time.time())
-        self.max_model_len = None
+        self.max_context_length = None
 
         # set real attributes if model is available
         self.models.list = partial(get_models_list, self)
diff --git a/app/schemas/chat.py b/app/schemas/chat.py
index 269f6f5..06bbc1a 100644
--- a/app/schemas/chat.py
+++ b/app/schemas/chat.py
@@ -46,7 +46,7 @@ def validate_model(cls, value):
         if not clients.models[value["model"]].check_context_length(messages=value["messages"]):
             raise ContextLengthExceededException()
 
-        if value["max_tokens"] is not None and value["max_tokens"] > clients.models[value["model"]].max_context_length:
+        if "max_tokens" in value and value["max_tokens"] is not None and value["max_tokens"] > clients.models[value["model"]].max_context_length:
             raise MaxTokensExceededException()
 
         return value
diff --git a/app/schemas/models.py b/app/schemas/models.py
index 8b25518..d9dcfed 100644
--- a/app/schemas/models.py
+++ b/app/schemas/models.py
@@ -1,4 +1,4 @@
-from typing import List, Literal
+from typing import List, Literal, Optional
 
 from openai.types import Model
 from pydantic import BaseModel
@@ -7,6 +7,7 @@
 
 
 class Model(Model):
+    max_context_length: Optional[int] = None
     type: Literal[LANGUAGE_MODEL_TYPE, EMBEDDINGS_MODEL_TYPE]
     status: Literal["available", "unavailable"]
 
diff --git a/app/tests/test_chat.py b/app/tests/test_chat.py
index 7351414..69c6728 100644
--- a/app/tests/test_chat.py
+++ b/app/tests/test_chat.py
@@ -15,11 +15,11 @@ def setup(args, session_user):
     response_json = response.json()
     model = [model for model in response_json["data"] if model["type"] == LANGUAGE_MODEL_TYPE][0]
     MODEL_ID = model["id"]
-    MAX_MODEL_LEN = model["max_model_len"]
+    MAX_CONTEXT_LENGTH = model["max_context_length"]
     logging.info(f"test model ID: {MODEL_ID}")
-    logging.info(f"test max model len: {MAX_MODEL_LEN}")
+    logging.info(f"test max context length: {MAX_CONTEXT_LENGTH}")
 
-    yield MODEL_ID, MAX_MODEL_LEN
+    yield MODEL_ID, MAX_CONTEXT_LENGTH
 
 
 @pytest.mark.usefixtures("args", "session_user", "setup")
@@ -78,7 +78,7 @@ def test_chat_completions_unknown_params(self, args, session_user, setup):
         assert response.status_code == 200, f"error: retrieve chat completions ({response.status_code})"
 
     def test_chat_completions_max_tokens_too_large(self, args, session_user, setup):
-        MODEL_ID, MAX_MODEL_LEN = setup
+        MODEL_ID, MAX_CONTEXT_LENGTH = setup
 
         prompt = "test"
         params = {
@@ -86,21 +86,21 @@ def test_chat_completions_max_tokens_too_large(self, args, session_user, setup):
             "messages": [{"role": "user", "content": prompt}],
             "stream": True,
             "n": 1,
-            "max_tokens": 1000000,
+            "max_tokens": MAX_CONTEXT_LENGTH + 10,
         }
         response = session_user.post(f"{args['base_url']}/chat/completions", json=params)
         assert response.status_code == 422, f"error: retrieve chat completions ({response.status_code})"
 
     def test_chat_completions_context_too_large(self, args, session_user, setup):
-        MODEL_ID, MAX_MODEL_LEN = setup
+        MODEL_ID, MAX_CONTEXT_LENGTH = setup
 
-        prompt = "test" * (MAX_MODEL_LEN + 100)
+        prompt = "test" * (MAX_CONTEXT_LENGTH + 10)
         params = {
             "model": MODEL_ID,
             "messages": [{"role": "user", "content": prompt}],
             "stream": True,
             "n": 1,
-            "max_tokens": 1000000,
+            "max_tokens": 10,
         }
         response = session_user.post(f"{args['base_url']}/chat/completions", json=params)
         assert response.status_code == 413, f"error: retrieve chat completions ({response.status_code})"