Skip to content

Commit

Permalink
feat: add search endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
leoguillaumegouv committed Sep 17, 2024
1 parent e3f4545 commit 12d4172
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 38 deletions.
28 changes: 28 additions & 0 deletions app/endpoints/search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from fastapi import APIRouter, Security

from app.helpers import VectorStore
from app.schemas.chunks import Chunks
from app.schemas.search import SearchRequest
from app.utils.lifespan import clients
from app.utils.security import check_api_key

router = APIRouter()


@router.post("/search")
async def search(request: SearchRequest, user: str = Security(check_api_key)) -> Chunks:
    """
    Run a similarity search over the vector store and return the matching chunks.

    Parameters:
        request (SearchRequest): Search parameters (prompt, target collections,
            number of results k, optional score threshold).
        user (str): API-key-authenticated user, resolved by the security dependency.

    Returns:
        Chunks: The matching chunks, as returned by the vector store.
    """
    # NOTE(review): VectorStore.search elsewhere in this codebase reads
    # self.models[model] from a `model` argument, which is not passed here —
    # confirm that `model` has a default or is otherwise resolved.
    store = VectorStore(clients=clients, user=user)

    results = store.search(
        prompt=request.prompt,
        collection_names=request.collections,
        k=request.k,
        score_threshold=request.score_threshold,
    )

    return Chunks(data=results)
44 changes: 19 additions & 25 deletions app/helpers/_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from qdrant_client.http.models import Distance, FieldCondition, Filter, MatchAny, PointIdsList, PointStruct, VectorParams

from app.schemas.chunks import Chunk
from app.schemas.collections import CollectionMetadata, Document
from langchain.docstore.document import Document
from app.schemas.collections import CollectionMetadata
from app.schemas.config import EMBEDDINGS_MODEL_TYPE, METADATA_COLLECTION, PRIVATE_COLLECTION_TYPE, PUBLIC_COLLECTION_TYPE


Expand All @@ -23,7 +24,7 @@ def from_documents(self, documents: List[Document], model: str, collection_name:
Add documents to a collection.
Parameters:
documents (List[Document]): A list of Document objects to add to the collection.
documents (List[Document]): A list of Langchain Document objects to add to the collection.
model (str): The model to use for embeddings.
collection_name (str): The name of the collection to add the documents to.
"""
Expand Down Expand Up @@ -63,11 +64,11 @@ def search(
k: Optional[int] = 4,
score_threshold: Optional[float] = None,
filter: Optional[Filter] = None,
) -> List[Document]:
) -> List[Chunk]:
response = self.models[model].embeddings.create(input=[prompt], model=model)
vector = response.data[0].embedding

documents = []
chunks = []
collections = self.get_collection_metadata(collection_names=collection_names)
for collection in collections:
if collection.model != model:
Expand All @@ -81,21 +82,20 @@ def search(
with_payload=True,
query_filter=filter,
)
for i, result in enumerate(results):
results[i] = result.model_dump()
results[i]["collection"] = collection.name

documents.extend(results)
chunks.extend(results)

# sort by similarity score and get top k
documents = sorted(documents, key=lambda x: x.score, reverse=True)[:k]
documents = [
Document(
id=document.id,
page_content=document.payload["page_content"],
metadata=document.payload["metadata"],
)
for document in documents
chunks = sorted(chunks, key=lambda x: x["score"], reverse=True)[:k]
chunks = [
Chunk(id=chunk["id"], collection=chunk["collection"], content=chunk["payload"]["page_content"], metadata=chunk["payload"]["metadata"])
for chunk in chunks
]

return documents
return chunks

def get_collection_metadata(self, collection_names: List[str] = [], type: str = "all", errors: str = "raise") -> List[CollectionMetadata]:
"""
Expand Down Expand Up @@ -258,15 +258,9 @@ def get_chunks(self, collection_name: str, filter: Optional[Filter] = None) -> L
scroll_filter=filter,
limit=100, # @TODO: add pagination
)[0]
data = list()
for chunk in chunks:
data.append(
Chunk(
collection=collection_name,
id=chunk.id,
metadata=chunk.payload["metadata"],
content=chunk.payload["page_content"],
)
)
chunks = [
Chunk(collection=collection_name, id=chunk.id, metadata=chunk.payload["metadata"], content=chunk.payload["page_content"])
for chunk in chunks
]

return data
return chunks
3 changes: 2 additions & 1 deletion app/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from fastapi import FastAPI, Response, Security

from app.endpoints import chat, chunks, collections, completions, embeddings, files, models, tools
from app.endpoints import chat, chunks, collections, completions, embeddings, files, models, search, tools
from app.utils.config import APP_CONTACT_EMAIL, APP_CONTACT_URL, APP_DESCRIPTION, APP_VERSION
from app.utils.lifespan import lifespan
from app.utils.security import check_api_key
Expand Down Expand Up @@ -31,4 +31,5 @@ def health(user: str = Security(check_api_key)):
app.include_router(collections.router, tags=["Collections"], prefix="/v1")
app.include_router(chunks.router, tags=["Chunks"], prefix="/v1")
app.include_router(files.router, tags=["Files"], prefix="/v1")
app.include_router(search.router, tags=["Search"], prefix="/v1")
app.include_router(tools.router, tags=["Tools"], prefix="/v1")
11 changes: 2 additions & 9 deletions app/schemas/collections.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from typing import Literal, List, Optional, Dict, Any
from uuid import UUID
from typing import List, Literal, Optional

from pydantic import BaseModel

from app.schemas.config import PUBLIC_COLLECTION_TYPE, PRIVATE_COLLECTION_TYPE
from app.schemas.config import PRIVATE_COLLECTION_TYPE, PUBLIC_COLLECTION_TYPE


class Collection(BaseModel):
Expand All @@ -22,12 +21,6 @@ class Collections(BaseModel):
data: List[Collection]


class Document(BaseModel):
id: UUID
page_content: str
metadata: Dict[str, Any]


class CollectionMetadata(BaseModel):
id: str
name: Optional[str] = None
Expand Down
10 changes: 10 additions & 0 deletions app/schemas/search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import List, Optional

from pydantic import BaseModel


class SearchRequest(BaseModel):
    """Request body for the /search endpoint.

    Attributes mirror the parameters of ``VectorStore.search``.
    """

    # Query text to embed and search with.
    prompt: str
    # Names of the collections to search in.
    collections: List[str]
    # Number of top results to return. Defaults to 4, matching the
    # VectorStore.search default, so callers no longer have to supply it.
    k: int = 4
    # Minimum similarity score; results below it are filtered out when set.
    score_threshold: Optional[float] = None
6 changes: 3 additions & 3 deletions app/tools/_baserag.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ async def get_prompt(
vectorstore = VectorStore(clients=self.clients, user=request["user"])
prompt = request["messages"][-1]["content"]

documents = vectorstore.search(model=embeddings_model, prompt=prompt, collection_names=collections, k=k)
chunks = vectorstore.search(model=embeddings_model, prompt=prompt, collection_names=collections, k=k)

metadata = {"chunks": [document.metadata for document in documents]}
documents = "\n\n".join([document.page_content for document in documents])
metadata = {"chunks": [chunk.metadata for chunk in chunks]}
documents = "\n\n".join([chunk.content for chunk in chunks])
prompt = prompt_template.format(documents=documents, prompt=prompt)

return ToolOutput(prompt=prompt, metadata=metadata)

0 comments on commit 12d4172

Please sign in to comment.