Skip to content

Commit

Permalink
feat: remove tools (#24)
Browse files Browse the repository at this point in the history
Co-authored-by: leoguillaume <[email protected]>
  • Loading branch information
leoguillaume and leoguillaumegouv authored Sep 24, 2024
1 parent 3208baa commit 68a87cf
Show file tree
Hide file tree
Showing 12 changed files with 80 additions and 312 deletions.
37 changes: 1 addition & 36 deletions app/endpoints/chat.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
from typing import Union

from fastapi import APIRouter, HTTPException, Security
Expand All @@ -7,16 +6,12 @@

from app.schemas.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionRequest
from app.schemas.config import LANGUAGE_MODEL_TYPE
from app.tools import *
from app.tools import __all__ as tools_list
from app.utils.config import LOGGER
from app.utils.lifespan import clients
from app.utils.security import check_api_key

router = APIRouter()


# @TODO: remove tooling from here
@router.post("/chat/completions")
async def chat_completions(request: ChatCompletionRequest, user: str = Security(check_api_key)) -> Union[ChatCompletion, ChatCompletionChunk]:
"""Completion API similar to OpenAI's API.
Expand All @@ -34,50 +29,20 @@ async def chat_completions(request: ChatCompletionRequest, user: str = Security(
if not client.check_context_length(model=request["model"], messages=request["messages"]):
raise HTTPException(status_code=400, detail="Context length too large")

# tool call
metadata = list()
tools = request.get("tools")
if tools:
for tool in tools:
if tool["function"]["name"] not in tools_list:
raise HTTPException(status_code=404, detail="Tool not found")
func = globals()[tool["function"]["name"]](clients=clients)
params = request | tool["function"]["parameters"]
params["user"] = user
LOGGER.debug(f"params: {params}")
try:
tool_output = await func.get_prompt(**params)
except Exception as e:
raise HTTPException(status_code=400, detail=f"tool error {e}")
metadata.append({tool["function"]["name"]: tool_output.model_dump()})
request["messages"] = [{"role": "user", "content": tool_output.prompt}]
request.pop("tools")

if not client.check_context_length(model=request["model"], messages=request["messages"]):
raise HTTPException(status_code=400, detail="Context length too large after tool call")

# non stream case
if not request["stream"]:
async_client = httpx.AsyncClient(timeout=20)
response = await async_client.request(method="POST", url=url, headers=headers, json=request)
print(response.text)
response.raise_for_status()
data = response.json()
data["metadata"] = metadata
return ChatCompletion(**data)

# stream case
async def forward_stream(url: str, headers: dict, request: dict):
async with httpx.AsyncClient(timeout=20) as async_client:
async with async_client.stream(method="POST", url=url, headers=headers, json=request) as response:
i = 0
async for chunk in response.aiter_raw():
if i == 0:
chunks = chunk.decode("utf-8").split("\n\n")
chunk = json.loads(chunks[0].lstrip("data: "))
chunk["metadata"] = metadata
chunks[0] = f"data: {json.dumps(chunk)}"
chunk = "\n\n".join(chunks).encode("utf-8")
i = 1
yield chunk

return StreamingResponse(forward_stream(url, headers, request), media_type="text/event-stream")
1 change: 0 additions & 1 deletion app/endpoints/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
router = APIRouter()


# @TODO: remove get one collection and a /collections/search to similarity search (remove /tools)
@router.get("/collections/{collection}")
@router.get("/collections")
async def get_collections(collection: Optional[str] = None, user: str = Security(check_api_key)) -> Union[Collection, Collections]:
Expand Down
5 changes: 1 addition & 4 deletions app/helpers/_universalparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def parse_and_chunk(self, file_path: str, chunk_size: int, chunk_overlap: int, c
pass

if file_type not in self.SUPPORTED_FILE_TYPES:
file_type = "unknown"
raise NotImplementedError(f"Unsupported input file format ({file_path}): {file_type}")

if file_type == self.PDF_TYPE:
chunks = self._pdf_to_chunks(file_path=file_path, chunk_size=chunk_size, chunk_overlap=chunk_overlap, chunk_min_size=chunk_min_size)
Expand All @@ -71,9 +71,6 @@ def parse_and_chunk(self, file_path: str, chunk_size: int, chunk_overlap: int, c
elif file_type == self.JSON_TYPE:
chunks = self._json_to_chunks(file_path=file_path, chunk_size=chunk_size, chunk_overlap=chunk_overlap, chunk_min_size=chunk_min_size)

else:
raise NotImplementedError(f"Unsupported input file format ({file_path}): {file_type}")

return chunks

## Parser and chunking functions
Expand Down
3 changes: 1 addition & 2 deletions app/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from fastapi import FastAPI, Response, Security

from app.endpoints import chat, chunks, collections, completions, embeddings, files, models, search, tools
from app.endpoints import chat, chunks, collections, completions, embeddings, files, models, search
from app.utils.config import APP_CONTACT_EMAIL, APP_CONTACT_URL, APP_DESCRIPTION, APP_VERSION
from app.utils.lifespan import lifespan
from app.utils.security import check_api_key
Expand Down Expand Up @@ -32,4 +32,3 @@ def health(user: str = Security(check_api_key)):
app.include_router(chunks.router, tags=["Chunks"], prefix="/v1")
app.include_router(files.router, tags=["Files"], prefix="/v1")
app.include_router(search.router, tags=["Search"], prefix="/v1")
app.include_router(tools.router, tags=["Tools"], prefix="/v1")
16 changes: 11 additions & 5 deletions app/schemas/chat.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Literal, Optional, Union
from typing import List, Literal, Optional, Union

from openai.types.chat import (
ChatCompletion,
Expand All @@ -9,10 +9,9 @@
)
from pydantic import BaseModel, Field

from app.schemas.tools import ToolOutput


class ChatCompletionRequest(BaseModel):
# See https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/protocol.py
messages: List[ChatCompletionMessageParam]
model: str
stream: Optional[Literal[True, False]] = False
Expand All @@ -27,11 +26,18 @@ class ChatCompletionRequest(BaseModel):
stop: Union[Optional[str], List[str]] = Field(default_factory=list)
tool_choice: Optional[Union[Literal["none"], ChatCompletionToolChoiceOptionParam]] = "none"
tools: List[ChatCompletionToolParam] = None
user: Optional[str] = None
best_of: Optional[int] = None
top_k: int = -1
min_p: float = 0.0

class ConfigDict:
extra = "allow"


class ChatCompletion(ChatCompletion):
    """API response schema; currently identical to OpenAI's ChatCompletion.

    Kept as a subclass so app-specific fields can be reintroduced later
    without touching the endpoints that reference this name.
    """

    pass


class ChatCompletionChunk(ChatCompletionChunk):
    """Streaming-chunk schema; currently identical to OpenAI's ChatCompletionChunk.

    Kept as a subclass so app-specific fields can be reintroduced later
    without touching the endpoints that reference this name.
    """

    pass
79 changes: 49 additions & 30 deletions app/tests/test_chat.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import json
import logging

import pytest

from app.schemas.chat import ChatCompletion # , ChatCompletionChunk
from app.schemas.chat import ChatCompletion, ChatCompletionChunk
from app.schemas.config import LANGUAGE_MODEL_TYPE


Expand Down Expand Up @@ -69,32 +70,50 @@ def test_chat_completions_streamed_response(self, args, session):
response = session.post(f"{args['base_url']}/chat/completions", json=params)
assert response.status_code == 200, f"error: retrieve chat completions ({response.status_code})"

# def test_chat_completions_streamed_response_schemas(self, args, session):
# """Test the GET /chat/completions response schemas."""
# # retrieve model
# response = session.get(f"{args['base_url']}/models")
# assert response.status_code == 200, f"error: retrieve models ({response.status_code})"
# response_json = response.json()
# model = [
# model["id"] for model in response_json["data"] if model["type"] == LANGUAGE_MODEL_TYPE
# ][0]
# logging.debug(f"model: {model}")

# params = {
# "model": model,
# "messages": [{"role": "user", "content": "Hello, how are you?"}],
# "stream": True,
# "n": 1,
# "max_tokens": 100,
# }
# response = session.post(f"{args['base_url']}/chat/completions", json=params)

# chunks = []
# for line in response.iter_lines():
# if line:
# chunk = json.loads(line.decode("utf-8").split("data: ")[1])
# chunks.append(chunk)
# chat_completion_chunk = ChatCompletionChunk(**chunk)
# assert isinstance(
# chat_completion_chunk, ChatCompletionChunk
# ), f"error: retrieve chat completions chunk {chunk}"
def test_chat_completions_streamed_response_schemas(self, args, session):
    """Test the GET /chat/completions response schemas.

    Requests a streamed completion and checks that every SSE payload
    parses into a ChatCompletionChunk.
    """
    # retrieve a language model to query
    response = session.get(f"{args['base_url']}/models")
    assert response.status_code == 200, f"error: retrieve models ({response.status_code})"
    response_json = response.json()
    model = [model["id"] for model in response_json["data"] if model["type"] == LANGUAGE_MODEL_TYPE][0]
    logging.debug(f"model: {model}")

    params = {
        "model": model,
        "messages": [{"role": "user", "content": "Hello, how are you?"}],
        "stream": True,
        "n": 1,
        "max_tokens": 100,
    }
    response = session.post(f"{args['base_url']}/chat/completions", json=params)

    for line in response.iter_lines():
        if line:
            # each SSE line is "data: <json>"; the stream ends with "data: [DONE]"
            chunk = line.decode("utf-8").split("data: ")[1]
            if chunk == "[DONE]":
                break
            chunk = json.loads(chunk)
            chat_completion_chunk = ChatCompletionChunk(**chunk)
            assert isinstance(chat_completion_chunk, ChatCompletionChunk), f"error: retrieve chat completions chunk {chunk}"

def test_chat_completions_unknown_params(self, args, session):
    """Test the GET /chat/completions response status code.

    Sends a parameter not declared in the ChatCompletionRequest schema
    ("min_tokens"); the endpoint should still answer 200 — presumably the
    request schema tolerates extra fields (verify against the schema).
    """
    # retrieve a language model to query
    response = session.get(f"{args['base_url']}/models")
    assert response.status_code == 200, f"error: retrieve models ({response.status_code})"
    response_json = response.json()
    model = [model["id"] for model in response_json["data"] if model["type"] == LANGUAGE_MODEL_TYPE][0]
    logging.debug(f"model: {model}")

    params = {
        "model": model,
        "messages": [{"role": "user", "content": "Hello, how are you?"}],
        "stream": True,
        "n": 1,
        "max_tokens": 10,
        "min_tokens": 3,  # unknown param in ChatCompletionRequest schema
    }
    response = session.post(f"{args['base_url']}/chat/completions", json=params)
    assert response.status_code == 200, f"error: retrieve chat completions ({response.status_code})"
35 changes: 13 additions & 22 deletions app/tests/test_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os

import pytest
import wget

from app.schemas.collections import Collection, Collections
from app.schemas.config import EMBEDDINGS_MODEL_TYPE, LANGUAGE_MODEL_TYPE, METADATA_COLLECTION, PRIVATE_COLLECTION_TYPE
Expand All @@ -14,19 +13,12 @@
@pytest.fixture(scope="function")
def setup(args, session):
    """Yield the test collection name, the sample PDF path and the user ID."""
    COLLECTION = "pytest"
    FILE_PATH = "app/tests/pytest.pdf"

    # the user ID is derived from the API key via encode_string
    USER = encode_string(input=args["api_key"])
    logging.info(f"test user ID: {USER}")

    yield COLLECTION, FILE_PATH, USER


@pytest.mark.usefixtures("args", "session", "setup")
Expand All @@ -52,21 +44,20 @@ def test_get_collections(self, args, session, setup):
assert isinstance(collections, Collections)
assert all(isinstance(collection, Collection) for collection in collections.data)

if COLLECTION in [collection.id for collection in collections.data]:
response = session.delete(f"{args['base_url']}/collections", params={"collection": COLLECTION}, timeout=10)
assert response.status_code == 204, f"error: delete collection ({response.status_code})"
response = session.delete(f"{args['base_url']}/collections", params={"collection": COLLECTION}, timeout=10)
assert response.status_code == 204 or response.status_code == 404, f"error: delete collection ({response.status_code})"

assert METADATA_COLLECTION not in [
collection.id for collection in collections.data
], f"{METADATA_COLLECTION} metadata collection is displayed in collections"

def test_upload_file(self, args, session, setup):
COLLECTION, FILE_NAME, _ = setup
COLLECTION, FILE_PATH, _ = setup
models = self.get_models(args, session)
EMBEDDINGS_MODEL = [model for model in models.data if model.type == EMBEDDINGS_MODEL_TYPE][0].id

params = {"embeddings_model": EMBEDDINGS_MODEL, "collection": COLLECTION}
files = {"files": (os.path.basename(FILE_NAME), open(FILE_NAME, "rb"), "application/pdf")}
files = {"files": (os.path.basename(FILE_PATH), open(FILE_PATH, "rb"), "application/pdf")}
response = session.post(f"{args['base_url']}/files", params=params, files=files, timeout=30)

assert response.status_code == 200, f"error: upload file ({response.status_code} - {response.text})"
Expand All @@ -87,7 +78,7 @@ def test_upload_file(self, args, session, setup):
files["data"] = [File(**file) for file in files["data"]]
assert len(files["data"]) == 1, f"error: number of files ({len(files)})"
files = Files(**files)
assert files.data[0].file_name == FILE_NAME, f"error: file name ({files.data[0].file_name})"
assert files.data[0].file_name == os.path.basename(FILE_PATH), f"error: file name ({files.data[0].file_name})"
assert files.data[0].id == file_id, f"error: file id ({files.data[0].id})"

def test_collection_creation(self, args, session, setup):
Expand All @@ -107,37 +98,37 @@ def test_collection_creation(self, args, session, setup):
assert collection.user == USER, f"{COLLECTION} collection user is not {USER}"

def test_upload_with_wrong_model(self, args, session, setup):
    """Uploading with a language model as the embeddings model must return 400."""
    COLLECTION, FILE_PATH, _ = setup
    models = self.get_models(args, session)

    language_model = [model for model in models.data if model.type == LANGUAGE_MODEL_TYPE][0].id
    params = {
        "embeddings_model": language_model,
        "collection": COLLECTION,
    }
    # close the file handle deterministically (the original relied on the GC)
    with open(FILE_PATH, "rb") as fh:
        files = {"files": (os.path.basename(FILE_PATH), fh, "application/pdf")}
        response = session.post(f"{args['base_url']}/files", params=params, files=files, timeout=10)

    assert response.status_code == 400, f"error: upload file ({response.status_code} - {response.text})"

def test_upload_with_non_existing_model(self, args, session, setup):
    """Uploading with an unknown embeddings model must return 404."""
    COLLECTION, FILE_PATH, _ = setup

    params = {
        "embeddings_model": "test",
        "collection": COLLECTION,
    }
    # close the file handle deterministically (the original relied on the GC)
    with open(FILE_PATH, "rb") as fh:
        files = {"files": (os.path.basename(FILE_PATH), fh, "application/pdf")}
        response = session.post(f"{args['base_url']}/files", params=params, files=files, timeout=10)

    assert response.status_code == 404, f"error: upload file ({response.status_code} - {response.text})"

def test_delete_file(self, args, session, setup):
COLLECTION, FILE_NAME, _ = setup
COLLECTION, FILE_PATH, _ = setup
models = self.get_models(args, session)
EMBEDDINGS_MODEL = [model for model in models.data if model.type == EMBEDDINGS_MODEL_TYPE][0].id
params = {"embeddings_model": EMBEDDINGS_MODEL, "collection": COLLECTION}
files = {"files": (os.path.basename(FILE_NAME), open(FILE_NAME, "rb"), "application/pdf")}
files = {"files": (os.path.basename(FILE_PATH), open(FILE_PATH, "rb"), "application/pdf")}
response = session.post(f"{args['base_url']}/files", params=params, files=files, timeout=30)

uploads = response.json()
Expand Down
Loading

0 comments on commit 68a87cf

Please sign in to comment.