Skip to content

Commit

Permalink
feat: exceptions handler
Browse files Browse the repository at this point in the history
  • Loading branch information
leoguillaumegouv committed Oct 8, 2024
1 parent 5a1ada9 commit 693b97c
Show file tree
Hide file tree
Showing 16 changed files with 101 additions and 61 deletions.
9 changes: 1 addition & 8 deletions app/endpoints/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from app.schemas.security import User
from app.utils.lifespan import clients
from app.utils.security import check_api_key
from app.utils.variables import LANGUAGE_MODEL_TYPE
from app.utils.exceptions import WrongModelTypeException, ContextLengthExceededException


router = APIRouter()

Expand All @@ -22,15 +21,9 @@ async def chat_completions(request: ChatCompletionRequest, user: User = Security

request = dict(request)
client = clients.models[request["model"]]
if client.type != LANGUAGE_MODEL_TYPE:
raise WrongModelTypeException()

url = f"{client.base_url}chat/completions"
headers = {"Authorization": f"Bearer {client.api_key}"}

if not client.check_context_length(model=request["model"], messages=request["messages"]):
raise ContextLengthExceededException()

# non stream case
if not request["stream"]:
async with httpx.AsyncClient(timeout=20) as async_client:
Expand Down
2 changes: 1 addition & 1 deletion app/endpoints/chunks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
async def get_chunks(
collection: UUID,
document: UUID,
limit: Optional[int] = Query(default=10, ge=1, le=10),
limit: Optional[int] = Query(default=10, ge=1, le=100),
offset: Optional[UUID] = None,
user: User = Security(check_api_key),
) -> Chunks:
Expand Down
9 changes: 0 additions & 9 deletions app/endpoints/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
from app.schemas.security import User
from app.utils.lifespan import clients
from app.utils.security import check_api_key
from app.utils.variables import LANGUAGE_MODEL_TYPE
from app.utils.exceptions import WrongModelTypeException, ContextLengthExceededException

router = APIRouter()

Expand All @@ -20,13 +18,6 @@ async def completions(request: CompletionRequest, user: User = Security(check_ap

request = dict(request)
client = clients.models[request["model"]]

if client.type != LANGUAGE_MODEL_TYPE:
raise WrongModelTypeException()

if not client.check_context_length(model=request["model"], messages=request["messages"]):
raise ContextLengthExceededException()

url = f"{client.base_url}completions"
headers = {"Authorization": f"Bearer {client.api_key}"}

Expand Down
2 changes: 1 addition & 1 deletion app/endpoints/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

@router.get("/documents/{collection}")
async def get_documents(
collection: UUID, limit: Optional[int] = Query(default=10, ge=1, le=10), offset: Optional[UUID] = None, user: User = Security(check_api_key)
collection: UUID, limit: Optional[int] = Query(default=10, ge=1, le=100), offset: Optional[UUID] = None, user: User = Security(check_api_key)
) -> Documents:
"""
Get all documents ID from a collection.
Expand Down
1 change: 1 addition & 0 deletions app/endpoints/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,5 @@ async def embeddings(request: EmbeddingsRequest, user: User = Security(check_api
raise e

data = response.json()

return Embeddings(**data)
12 changes: 8 additions & 4 deletions app/helpers/_fileuploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,21 @@

from fastapi import UploadFile

from app.schemas.data import ParserOutput
from app.helpers.chunkers import *
from app.helpers.parsers import HTMLParser, JSONParser, PDFParser
from app.schemas.chunks import Chunk
from app.schemas.data import ParserOutput
from app.schemas.security import User
from app.utils.exceptions import InvalidJSONFormatException, NoChunksToUpsertException, ParsingFileFailedException, UnsupportedFileTypeException
from app.utils.variables import (
CHUNKERS,
DEFAULT_CHUNKER,
HTML_TYPE,
JSON_TYPE,
PDF_TYPE,
)

from ._vectorstore import VectorStore
from app.utils.exceptions import ParsingFileFailedException, UnsupportedFileTypeException, NoChunksToUpsertException


class FileUploader:
Expand All @@ -37,7 +38,7 @@ def parse(self, file: UploadFile) -> List[ParserOutput]:
if file_type not in self.TYPE_DICT.keys():
raise UnsupportedFileTypeException()

file_type = self.TYPE_DICT[file.filename.split(".")[-1]]
file_type = self.TYPE_DICT[file_type]

if file_type == PDF_TYPE:
parser = PDFParser(collection_id=self.collection_id)
Expand All @@ -51,7 +52,10 @@ def parse(self, file: UploadFile) -> List[ParserOutput]:
try:
output = parser.parse(file=file)
except Exception as e:
raise ParsingFileFailedException()
if isinstance(e, InvalidJSONFormatException):
raise e
else:
raise ParsingFileFailedException()

return output

Expand Down
6 changes: 4 additions & 2 deletions app/helpers/_modelclients.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,14 @@ def get_models_list(self, *args, **kwargs):
return Models(data=data)


def check_context_length(self, messages: List[Dict[str, str]], add_special_tokens: bool = True) -> bool:
    """
    Check whether the prompt built from `messages` fits in the model context window.

    The messages are flattened into a single "role: content" prompt (one message per
    line) and posted to the tokenize route derived from `self.base_url`.

    Args:
        messages: chat messages; each one is read through its "role" and "content" keys.
        add_special_tokens: forwarded as-is in the tokenize request payload.

    Returns:
        True when the "count" reported by the server is lower than or equal to the
        reported "max_model_len", False otherwise.

    Raises:
        requests.HTTPError: when the tokenize route answers with an error status.
    """
    headers = {"Authorization": f"Bearer {self.api_key}"}
    prompt = "\n".join([message["role"] + ": " + message["content"] for message in messages])
    # The model is always the one bound to this client, so callers no longer pass it.
    data = {"model": self.id, "prompt": prompt, "add_special_tokens": add_special_tokens}

    response = requests.post(str(self.base_url).replace("/v1/", "/tokenize"), json=data, headers=headers)
    response.raise_for_status()

    # Decode the body once instead of re-parsing it for each field.
    payload = response.json()

    return payload["count"] <= payload["max_model_len"]


Expand Down Expand Up @@ -95,6 +96,7 @@ def __init__(self, type=Literal[EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE], sea
response = self.models.list()
model = response.data[0]
self.id = model.id
self.max_context_length = model.max_model_len

if self.type == EMBEDDINGS_MODEL_TYPE:
response = self.embeddings.create(model=self.id, input="hello world")
Expand Down
18 changes: 17 additions & 1 deletion app/schemas/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
ChatCompletionToolChoiceOptionParam,
ChatCompletionToolParam,
)
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator

from app.utils.exceptions import ContextLengthExceededException, MaxTokensExceededException, WrongModelTypeException
from app.utils.lifespan import clients
from app.utils.variables import LANGUAGE_MODEL_TYPE


class ChatCompletionRequest(BaseModel):
Expand All @@ -34,6 +38,18 @@ class ChatCompletionRequest(BaseModel):
class ConfigDict:
extra = "allow"

@model_validator(mode="before")
@classmethod
def validate_model(cls, value):
    """
    Validate the raw chat completion payload against the targeted model.

    Runs before field validation, so `value` is the payload as sent by the caller:
    optional fields that were omitted are absent rather than set to their default.

    Raises:
        WrongModelTypeException: the requested model is not a language model.
        ContextLengthExceededException: the messages do not fit in the model context.
        MaxTokensExceededException: `max_tokens` is larger than the model context length.
    """
    client = clients.models[value["model"]]

    if client.type != LANGUAGE_MODEL_TYPE:
        raise WrongModelTypeException()

    if not client.check_context_length(messages=value["messages"]):
        raise ContextLengthExceededException()

    # `max_tokens` is optional: read it with .get() so an omitted key is not a KeyError.
    max_tokens = value.get("max_tokens")
    if max_tokens is not None and max_tokens > client.max_context_length:
        raise MaxTokensExceededException()

    return value


class ChatCompletion(ChatCompletion):
pass
Expand Down
17 changes: 16 additions & 1 deletion app/schemas/completions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from typing import Dict, Iterable, List, Optional, Union

from openai.types import Completion
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator

from app.utils.lifespan import clients
from app.utils.variables import LANGUAGE_MODEL_TYPE
from app.utils.exceptions import WrongModelTypeException, ContextLengthExceededException, MaxTokensExceededException


class CompletionRequest(BaseModel):
Expand All @@ -23,6 +27,17 @@ class CompletionRequest(BaseModel):
top_p: Optional[float] = 1.0
user: Optional[str] = None

@model_validator(mode="before")
@classmethod
def validate_model(cls, value):
    """
    Validate the raw completion payload against the targeted model.

    Runs before field validation, so `value` is the payload as sent by the caller:
    optional fields that were omitted are absent rather than set to their default.

    Raises:
        WrongModelTypeException: the requested model is not a language model.
        ContextLengthExceededException: the input does not fit in the model context.
        MaxTokensExceededException: `max_tokens` is larger than the model context length.
    """
    client = clients.models[value["model"]]

    if client.type != LANGUAGE_MODEL_TYPE:
        raise WrongModelTypeException()

    # NOTE(review): completion payloads usually carry `prompt`, not `messages` - confirm
    # this key exists on CompletionRequest, otherwise this lookup raises KeyError.
    if not client.check_context_length(messages=value["messages"]):
        raise ContextLengthExceededException()

    # `max_tokens` is optional: read it with .get() so an omitted key is not a KeyError.
    max_tokens = value.get("max_tokens")
    if max_tokens is not None and max_tokens > client.max_context_length:
        raise MaxTokensExceededException()

    # A "before" validator must hand the payload back, otherwise the model is
    # validated against None and every request is rejected.
    return value


class Completions(Completion):
pass
3 changes: 0 additions & 3 deletions app/schemas/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@ class ChunkerArgs(BaseModel):
# additional arguments
chunk_min_size: int = Field(0)

class Config:
extra = "allow"


class Chunker(BaseModel):
name: Optional[Literal[*CHUNKERS]] = Field(None)
Expand Down
16 changes: 15 additions & 1 deletion app/tests/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,27 @@ def test_chat_completions_unknown_params(self, args, session_user, setup):
def test_chat_completions_max_tokens_too_large(self, args, session_user, setup):
    """A chat completion asking for 1000000 max tokens must be answered with a 422."""
    model_id, _max_model_len = setup

    endpoint = f"{args['base_url']}/chat/completions"
    payload = dict(
        model=model_id,
        messages=[{"role": "user", "content": "test"}],
        stream=True,
        n=1,
        max_tokens=1000000,
    )

    response = session_user.post(endpoint, json=payload)
    assert response.status_code == 422, f"error: retrieve chat completions ({response.status_code})"

def test_chat_completions_context_too_large(self, args, session_user, setup):
    """A chat completion whose prompt is built to overflow the context must be answered with a 413."""
    MODEL_ID, MAX_MODEL_LEN = setup

    # Repeat the word more times than the model context length reported by the fixture.
    prompt = "test" * (MAX_MODEL_LEN + 100)
    params = {
        "model": MODEL_ID,
        "messages": [{"role": "user", "content": prompt}],
        "stream": True,
        "n": 1,
        # Single entry: a duplicated "max_tokens" key would be silently overridden by the last one.
        "max_tokens": 1000000,
    }
    response = session_user.post(f"{args['base_url']}/chat/completions", json=params)
    assert response.status_code == 413, f"error: retrieve chat completions ({response.status_code})"
8 changes: 4 additions & 4 deletions app/tests/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_create_public_collection_with_user(self, args, session_user, setup):

params = {"name": PUBLIC_COLLECTION_NAME, "model": EMBEDDINGS_MODEL_ID, "type": PUBLIC_COLLECTION_TYPE}
response = session_user.post(f"{args["base_url"]}/collections", json=params)
assert response.status_code == 400
assert response.status_code == 422

def test_create_public_collection_with_admin(self, args, session_admin, setup):
PUBLIC_COLLECTION_NAME, _, _, _, EMBEDDINGS_MODEL_ID, _ = setup
Expand All @@ -63,7 +63,7 @@ def test_create_private_collection_with_language_model_with_user(self, args, ses

params = {"name": PRIVATE_COLLECTION_NAME, "model": LANGUAGE_MODEL_ID, "type": PRIVATE_COLLECTION_TYPE}
response = session_user.post(f"{args["base_url"]}/collections", json=params)
assert response.status_code == 400
assert response.status_code == 422

def test_create_private_collection_with_unknown_model_with_user(self, args, session_user, setup):
_, PRIVATE_COLLECTION_NAME, _, _, _, _ = setup
Expand Down Expand Up @@ -117,7 +117,7 @@ def test_delete_public_collection_with_user(self, args, session_user, setup):
response = session_user.get(f"{args["base_url"]}/collections")
collection_id = [collection["id"] for collection in response.json()["data"] if collection["name"] == PUBLIC_COLLECTION_NAME][0]
response = session_user.delete(f"{args["base_url"]}/collections/{collection_id}")
assert response.status_code == 400
assert response.status_code == 422

def test_delete_public_collection_with_admin(self, args, session_admin, setup):
PUBLIC_COLLECTION_NAME, _, _, _, _, _ = setup
Expand All @@ -132,7 +132,7 @@ def test_create_internet_collection_with_user(self, args, session_user, setup):

params = {"name": INTERNET_COLLECTION_ID, "model": EMBEDDINGS_MODEL_ID, "type": PUBLIC_COLLECTION_TYPE}
response = session_user.post(f"{args["base_url"]}/collections", json=params)
assert response.status_code == 400
assert response.status_code == 422

def test_create_collection_with_empty_name(self, args, session_user, setup):
_, _, _, _, EMBEDDINGS_MODEL_ID, _ = setup
Expand Down
4 changes: 2 additions & 2 deletions app/tests/test_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def test_upload_json_file_wrong_format(self, args, session_user, setup):
files = {"file": (os.path.basename(file_path), open(file_path, "rb"), "application/json")}
data = {"request": '{"collection": "%s"}' % PRIVATE_COLLECTION_ID}
response = session_user.post(f"{args["base_url"]}/files", data=data, files=files)
assert response.status_code == 400, f"error: upload file ({response.status_code} - {response.text})"
assert response.status_code == 422, f"error: upload file ({response.status_code} - {response.text})"

def test_upload_too_large_file(self, args, session_user, setup):
PRIVATE_COLLECTION_ID, _ = setup
Expand All @@ -111,4 +111,4 @@ def test_upload_in_public_collection_with_user(self, args, session_user, setup):
files = {"file": (os.path.basename(file_path), open(file_path, "rb"), "application/pdf")}
data = {"request": '{"collection": "%s"}' % PUBLIC_COLLECTION_ID}
response = session_user.post(f"{args["base_url"]}/files", data=data, files=files)
assert response.status_code == 400, f"error: upload file ({response.status_code} - {response.text})"
assert response.status_code == 422, f"error: upload file ({response.status_code} - {response.text})"
45 changes: 25 additions & 20 deletions app/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,6 @@ def __init__(self, detail: str = "Parsing file failed."):
super().__init__(status_code=400, detail=detail)


class WrongModelTypeException(HTTPException):
def __init__(self, detail: str = "Wrong model type."):
super().__init__(status_code=400, detail=detail)


class WrongCollectionTypeException(HTTPException):
def __init__(self, detail: str = "Wrong collection type."):
super().__init__(status_code=400, detail=detail)


class DifferentCollectionsModelsException(HTTPException):
def __init__(self, detail: str = "Different collections models."):
super().__init__(status_code=400, detail=detail)


class UnsupportedFileTypeException(HTTPException):
def __init__(self, detail: str = "Unsupported file type."):
super().__init__(status_code=400, detail=detail)


class NoChunksToUpsertException(HTTPException):
    """HTTP 400 error; the detail defaults to "No chunks to upsert."."""

    def __init__(self, detail: str = "No chunks to upsert."):
        super().__init__(detail=detail, status_code=400)
Expand Down Expand Up @@ -69,3 +49,28 @@ def __init__(self, detail: str = "File size limit exceeded."):
class InvalidJSONFormatException(HTTPException):
    """HTTP 422 error; the detail defaults to "Invalid JSON file format."."""

    def __init__(self, detail: str = "Invalid JSON file format."):
        super().__init__(detail=detail, status_code=422)


class WrongModelTypeException(HTTPException):
    """HTTP 422 error; the detail defaults to "Wrong model type."."""

    def __init__(self, detail: str = "Wrong model type."):
        super().__init__(detail=detail, status_code=422)


class MaxTokensExceededException(HTTPException):
    """HTTP 422 error; the detail defaults to "Max tokens exceeded."."""

    def __init__(self, detail: str = "Max tokens exceeded."):
        super().__init__(detail=detail, status_code=422)


class WrongCollectionTypeException(HTTPException):
    """HTTP 422 error; the detail defaults to "Wrong collection type."."""

    def __init__(self, detail: str = "Wrong collection type."):
        super().__init__(detail=detail, status_code=422)


class DifferentCollectionsModelsException(HTTPException):
    """HTTP 422 error; the detail defaults to "Different collections models."."""

    def __init__(self, detail: str = "Different collections models."):
        super().__init__(detail=detail, status_code=422)


class UnsupportedFileTypeException(HTTPException):
    """HTTP 422 error; the detail defaults to "Unsupported file type."."""

    def __init__(self, detail: str = "Unsupported file type."):
        super().__init__(detail=detail, status_code=422)
8 changes: 5 additions & 3 deletions ui/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,11 @@
else:
messages = st.session_state.messages
sources = []

stream = openai_client.chat.completions.create(stream=True, messages=messages, **params["sampling_params"])
response = st.write_stream(stream)
try:
stream = openai_client.chat.completions.create(stream=True, messages=messages, **params["sampling_params"])
response = st.write_stream(stream)
except Exception as e:
st.error(e)
if sources:
st.multiselect(options=sources, label="Sources", key="sources_tmp", default=sources)

Expand Down
2 changes: 1 addition & 1 deletion ui/pages/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
with col1:
with st.expander("Upload a file", icon="📑"):
collection = st.selectbox(
"Select collection to delete", [f"{collection["name"]} - {collection["id"]}" for collection in collections], key="upload_file_selectbox"
"Select a collection", [f"{collection["name"]} - {collection["id"]}" for collection in collections], key="upload_file_selectbox"
)
collection_id = collection.split(" - ")[-1]
file_to_upload = st.file_uploader("File", type=["pdf", "html", "json"])
Expand Down

0 comments on commit 693b97c

Please sign in to comment.