fix: max context length
leoguillaumegouv committed Oct 9, 2024
1 parent 5c8dda5 commit 34d2f09
Showing 4 changed files with 18 additions and 16 deletions.
13 changes: 7 additions & 6 deletions app/helpers/_modelclients.py
@@ -34,7 +34,7 @@ def get_models_list(self, *args, **kwargs):
self.id = response["id"]
self.owned_by = response.get("owned_by", "")
self.created = response.get("created", round(time.time()))
- self.max_model_len = response.get("max_model_len", None)
+ self.max_context_length = response.get("max_model_len", None)

elif self.type == EMBEDDINGS_MODEL_TYPE:
endpoint = str(self.base_url).replace("/v1/", "/info")
@@ -43,7 +43,7 @@ def get_models_list(self, *args, **kwargs):
self.id = response["model_id"]
self.owned_by = "huggingface-text-embeddings-inference"
self.created = round(time.time())
- self.max_model_len = response.get("max_input_length", None)
+ self.max_context_length = response.get("max_input_length", None)

self.status = "available"

@@ -55,7 +55,7 @@ def get_models_list(self, *args, **kwargs):
object="model",
owned_by=self.owned_by,
created=self.created,
- max_model_len=self.max_model_len,
+ max_context_length=self.max_context_length,
type=self.type,
status=self.status,
)
@@ -64,6 +64,7 @@ def get_models_list(self, *args, **kwargs):


def check_context_length(self, messages: List[Dict[str, str]], add_special_tokens: bool = True):
+ # TODO: remove this method and use better context length handling
headers = {"Authorization": f"Bearer {self.api_key}"}
prompt = "\n".join([message["role"] + ": " + message["content"] for message in messages])

@@ -77,9 +78,9 @@ def check_context_length(self, messages: List[Dict[str, str]], add_special_tokens: bool = True):
response = response.json()

if self.type == LANGUAGE_MODEL_TYPE:
return response["count"] <= response["max_model_len"]
return response["count"] <= self.max_context_length
elif self.type == EMBEDDINGS_MODEL_TYPE:
- return len(response[0]) <= self.max_model_len
+ return len(response[0]) <= self.max_context_length


def create_embeddings(self, *args, **kwargs):
@@ -110,7 +111,7 @@ def __init__(self, type=Literal[EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE], *args, **kwargs):
self.id = ""
self.owned_by = ""
self.created = round(time.time())
- self.max_model_len = None
+ self.max_context_length = None

# set real attributes if model is available
self.models.list = partial(get_models_list, self)
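
The rename unifies two backend-specific fields under one attribute: vLLM-style language-model endpoints report the context window as max_model_len on /v1/models, while HuggingFace text-embeddings-inference reports max_input_length on /info. A minimal sketch of that normalization, assuming those two payload shapes (the helper name and the type constants' values are illustrative, not the repository's):

```python
# Sketch of the normalization performed in the diff above; the helper name
# and the constant values are assumptions for illustration.
from typing import Optional

LANGUAGE_MODEL_TYPE = "text-generation"               # assumed value
EMBEDDINGS_MODEL_TYPE = "text-embeddings-inference"   # assumed value

def normalize_max_context_length(model_type: str, response: dict) -> Optional[int]:
    """Map each backend's limit field onto the unified max_context_length."""
    if model_type == LANGUAGE_MODEL_TYPE:
        return response.get("max_model_len")       # vLLM-style /v1/models payload
    if model_type == EMBEDDINGS_MODEL_TYPE:
        return response.get("max_input_length")    # text-embeddings-inference /info payload
    return None
```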
2 changes: 1 addition & 1 deletion app/schemas/chat.py
@@ -46,7 +46,7 @@ def validate_model(cls, value):
if not clients.models[value["model"]].check_context_length(messages=value["messages"]):
raise ContextLengthExceededException()

if value["max_tokens"] is not None and value["max_tokens"] > clients.models[value["model"]].max_context_length:
if "max_tokens" in value and value["max_tokens"] is not None and value["max_tokens"] > clients.models[value["model"]].max_context_length:
raise MaxTokensExceededException()
return value

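
The added membership test matters because value is validated as a plain dict here: the old condition raised a KeyError whenever a client omitted max_tokens entirely, instead of skipping the check. A standalone sketch of the guarded comparison, with placeholder names for the repository's exception and validator:

```python
# Standalone sketch of the guarded check; MaxTokensExceededException and
# check_max_tokens are placeholders, not the repository's actual objects.
class MaxTokensExceededException(Exception):
    pass

def check_max_tokens(value: dict, max_context_length: int) -> None:
    # Equivalent, terser guard: value.get("max_tokens") is not None and ...
    if "max_tokens" in value and value["max_tokens"] is not None and value["max_tokens"] > max_context_length:
        raise MaxTokensExceededException()

check_max_tokens({"max_tokens": None}, 8192)  # ok: explicit null
check_max_tokens({}, 8192)                    # ok: key absent (previously a KeyError)
check_max_tokens({"max_tokens": 10}, 8192)    # ok: within the window
```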
3 changes: 2 additions & 1 deletion app/schemas/models.py
@@ -1,4 +1,4 @@
- from typing import List, Literal
+ from typing import List, Literal, Optional

from openai.types import Model
from pydantic import BaseModel
@@ -7,6 +7,7 @@


class Model(Model):
+ max_context_length: Optional[int] = None
type: Literal[LANGUAGE_MODEL_TYPE, EMBEDDINGS_MODEL_TYPE]
status: Literal["available", "unavailable"]

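
With the new optional field, /v1/models responses can advertise the limit alongside the standard OpenAI metadata. An isolated sketch of the extended schema, assuming openai>=1.x (where openai.types.Model is a pydantic model); the type literals and field values are illustrative:

```python
# Isolated sketch of the extended schema; literal values are assumptions.
from typing import Literal, Optional

from openai.types import Model as OpenAIModel

class Model(OpenAIModel):
    max_context_length: Optional[int] = None  # None when the backend reports no limit
    type: Literal["text-generation", "text-embeddings-inference"]
    status: Literal["available", "unavailable"]

m = Model(
    id="demo-model",
    created=1728432000,
    object="model",
    owned_by="vllm",
    max_context_length=8192,
    type="text-generation",
    status="available",
)
print(m.max_context_length)  # 8192
```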
16 changes: 8 additions & 8 deletions app/tests/test_chat.py
@@ -15,11 +15,11 @@ def setup(args, session_user):
response_json = response.json()
model = [model for model in response_json["data"] if model["type"] == LANGUAGE_MODEL_TYPE][0]
MODEL_ID = model["id"]
- MAX_MODEL_LEN = model["max_model_len"]
+ MAX_CONTEXT_LENGTH = model["max_context_length"]
logging.info(f"test model ID: {MODEL_ID}")
logging.info(f"test max model len: {MAX_MODEL_LEN}")
logging.info(f"test max context length: {MAX_CONTEXT_LENGTH}")

- yield MODEL_ID, MAX_MODEL_LEN
+ yield MODEL_ID, MAX_CONTEXT_LENGTH


@pytest.mark.usefixtures("args", "session_user", "setup")
@@ -78,29 +78,29 @@ def test_chat_completions_unknown_params(self, args, session_user, setup):
assert response.status_code == 200, f"error: retrieve chat completions ({response.status_code})"

def test_chat_completions_max_tokens_too_large(self, args, session_user, setup):
- MODEL_ID, MAX_MODEL_LEN = setup
+ MODEL_ID, MAX_CONTEXT_LENGTH = setup

prompt = "test"
params = {
"model": MODEL_ID,
"messages": [{"role": "user", "content": prompt}],
"stream": True,
"n": 1,
"max_tokens": 1000000,
"max_tokens": MAX_CONTEXT_LENGTH + 10,
}
response = session_user.post(f"{args['base_url']}/chat/completions", json=params)
assert response.status_code == 422, f"error: retrieve chat completions ({response.status_code})"

def test_chat_completions_context_too_large(self, args, session_user, setup):
- MODEL_ID, MAX_MODEL_LEN = setup
+ MODEL_ID, MAX_CONTEXT_LENGTH = setup

prompt = "test" * (MAX_MODEL_LEN + 100)
prompt = "test" * (MAX_CONTEXT_LENGTH + 10)
params = {
"model": MODEL_ID,
"messages": [{"role": "user", "content": prompt}],
"stream": True,
"n": 1,
"max_tokens": 1000000,
"max_tokens": 10,
}
response = session_user.post(f"{args['base_url']}/chat/completions", json=params)
assert response.status_code == 413, f"error: retrieve chat completions ({response.status_code})"
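
Deriving the oversized values from the advertised limit keeps the two failure modes distinct on any deployed model: 422 when max_tokens alone exceeds the context window, 413 when the prompt itself does. A hypothetical standalone reproduction against an OpenAI-compatible API (the base URL and token are placeholders):

```python
# Hypothetical reproduction of both failure modes; BASE_URL and the token
# are placeholders for a real deployment.
import requests

BASE_URL = "http://localhost:8000/v1"
HEADERS = {"Authorization": "Bearer <token>"}

model = requests.get(f"{BASE_URL}/models", headers=HEADERS).json()["data"][0]
limit = model["max_context_length"]

# 422: the parameter is invalid on its own (max_tokens > context window)
r = requests.post(f"{BASE_URL}/chat/completions", headers=HEADERS,
                  json={"model": model["id"],
                        "messages": [{"role": "user", "content": "test"}],
                        "max_tokens": limit + 10})
assert r.status_code == 422

# 413: the prompt itself exceeds the model's context window
r = requests.post(f"{BASE_URL}/chat/completions", headers=HEADERS,
                  json={"model": model["id"],
                        "messages": [{"role": "user", "content": "test" * (limit + 10)}],
                        "max_tokens": 10})
assert r.status_code == 413
```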
