Skip to content

Commit

Permalink
feat: add query expansion using an LLM
Browse files Browse the repository at this point in the history
  • Loading branch information
ydennisy committed Jul 14, 2024
1 parent 480a7f7 commit 9315a18
Show file tree
Hide file tree
Showing 8 changed files with 216 additions and 93 deletions.
9 changes: 5 additions & 4 deletions backend/Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,16 @@ supabase = "==2.5.1"
python-dotenv = "==1.0.1"
trafilatura = "==1.11.0"
lxml = "==5.2.2"
langchain = "==0.1.16"
langchain = "==0.2.7"
requests = "==2.32.3"
openai = "==1.35.13"
llama-index-readers-file = "==0.1.19"
llama-index-readers-file = "==0.1.30"
httpx = "==0.27.0"
pymupdf = "==1.24.2"
pymupdf = "==1.24.7"
pydantic = {extras = ["email"], version = "==2.7.4"}
python-multipart = "==0.0.9"
umap-learn = "*"
umap-learn = "==0.5.6"
instructor = "==1.3.4"

[dev-packages]

Expand Down
213 changes: 148 additions & 65 deletions backend/Pipfile.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion backend/app/domain/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def __init__(
self.summary = summary
self.embedding = None
self.chunks: list[TextNodeChunk] = []
self.create_title_if_missing()

def create_chunks(self, chunker: NodeChunker) -> None:
self.chunks = chunker.chunk(self.id, self.text)
Expand All @@ -43,6 +42,7 @@ async def create_embeddings(self, embedder: NodeEmbedder) -> None:
chunk.embedding = embedding

# TODO: this can be done using an LLM.
# NOTE: this method is switched off for now.
def create_title_if_missing(self) -> None:
if not self.title:
words = self.text.split()[:10]
Expand Down
43 changes: 39 additions & 4 deletions backend/app/llm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import json
from enum import Enum
from typing import List, Generator, Any
from typing import Generator, Any

import instructor
from openai import OpenAI
from pydantic import BaseModel, Field

client = OpenAI()

Expand All @@ -28,7 +31,7 @@ class Models(Enum):
)


def format_chunks(chunks: List[dict]) -> str:
def format_chunks(chunks: list[dict]) -> str:
result = ""
for chunk in chunks:
chunk.pop("id")
Expand All @@ -38,7 +41,7 @@ def format_chunks(chunks: List[dict]) -> str:
return result


def answer_with_context(chunks: List[dict], question: str) -> Generator[str, Any, Any]:
def answer_with_context(chunks: list[dict], question: str) -> Generator[str, Any, Any]:
formatted_chunks = format_chunks(chunks)
messages = [
{
Expand Down Expand Up @@ -95,7 +98,7 @@ def summarise_concept(texts: list[str]) -> str:
DO NOT EXPLAIN OR REPEAT EACH TEXT."""

result = client.chat.completions.create(
model="gpt-4o",
model=Models.GPT_4o_LATEST.value,
messages=[
{
"role": "user",
Expand All @@ -104,3 +107,35 @@ def summarise_concept(texts: list[str]) -> str:
],
)
return result.choices[0].message.content


class ExpandedQueries(BaseModel):
    """Structured LLM output schema: a bounded list of expanded search queries."""

    # Pydantic v2 (pinned as 2.7.4 in the Pipfile) uses `max_length` to
    # constrain list fields; `max_items` is the deprecated v1 alias and
    # emits a PydanticDeprecatedSince20 warning at class definition time.
    queries: list[str] = Field(
        ..., max_length=5, description="List of expanded search queries"
    )


def expand_search_query(query: str) -> list[str]:
    """Expand a user's search query into up to 5 related queries using an LLM.

    The completion is coerced into the ``ExpandedQueries`` schema via
    ``instructor``, so callers receive a plain list of query strings.

    Args:
        query: The user's original search query.

    Returns:
        Up to 5 expanded queries; the prompt instructs the model to exclude
        the original query itself.
    """
    # A local, instructor-patched client; intentionally separate from the
    # module-level `client`, which is a plain OpenAI client without
    # `response_model` support.
    llm_client = instructor.from_openai(OpenAI())
    prompt = f"""Given the user's search query, generate up to 5 expanded queries for improved information retrieval.
Each query should be related to the original but explore different aspects and use alternative terminology, especially when technical concepts have different names.
DO NOT include the original query in the expanded queries.
DO NOT simply rephrase the original query. The goal is to explore different aspects of the topic and related topics and ensure a breadth of keywords.
Original query: {query}
Provide the expanded queries as a list."""

    messages = [
        {
            "role": "user",
            "content": prompt,
        }
    ]
    # temperature > 1 deliberately trades determinism for diversity in the
    # generated expansions.
    result = llm_client.chat.completions.create(
        messages=messages,
        response_model=ExpandedQueries,
        model=Models.GPT_4o_LATEST.value,
        temperature=1.2,
    )
    return result.queries
19 changes: 17 additions & 2 deletions backend/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pydantic import BaseModel, Field, ConfigDict, EmailStr

from app.db import DB
from app.llm import answer_with_context
from app.llm import answer_with_context, expand_search_query
from app.utils import (
NodeEmbedder,
get_current_user,
Expand Down Expand Up @@ -89,7 +89,7 @@ async def get_explore_route(user=Depends(get_current_user)):

@app.get("/api/search")
async def get_search_route(
q: str, mode: Literal["hybrid", "dense"], user=Depends(get_current_user)
q: str, mode: Literal["hybrid", "dense", "llm"], user=Depends(get_current_user)
):
user_id = user.id
query_emb = await NodeEmbedder.embed(q)
Expand All @@ -99,6 +99,21 @@ async def get_search_route(
elif mode == "dense":
pages = db.search_text_nodes(query_emb, user_id=user_id, threshold=0.1)
return pages
elif mode == "llm":
results = []
pages = db.search_text_nodes(query_emb, user_id=user_id, threshold=0.1)
results.extend(pages)

queries = expand_search_query(q)
for query in queries:
query_emb = await NodeEmbedder.embed(query)
pages = db.search_text_nodes(query_emb, user_id=user_id, threshold=0.1)
results.extend(pages)
results = list({v["id"]: v for v in results}.values())
# rank results by score key
results.sort(key=lambda x: x["score"], reverse=True)
print(results)
return results
else:
raise HTTPException(400)

Expand Down
8 changes: 3 additions & 5 deletions backend/app/utils/chunker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

from app.domain import TextNodeChunk

CHUNK_SIZE = 512
CHUNK_OVERLAP = 64
CHUNK_SIZE = 2048
CHUNK_OVERLAP = 256

text_splitter = RecursiveCharacterTextSplitter(
separators=["\n\n", "\n", " ", ""],
Expand All @@ -23,6 +23,4 @@ class NodeChunker:
def chunk(node_id: str, node_text: str) -> list[TextNodeChunk]:
    """Split raw node text into chunks and wrap each in a TextNodeChunk.

    Args:
        node_id: Id of the owning text node, stored on every produced chunk.
        node_text: The raw text to split.

    Returns:
        One TextNodeChunk per chunk emitted by the module-level
        `text_splitter` (order preserved).
    """
    documents = text_splitter.create_documents(texts=[node_text])
    texts = [doc.page_content for doc in documents]
    return [TextNodeChunk(text=text, text_node_id=node_id) for text in texts]
6 changes: 3 additions & 3 deletions backend/app/utils/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ async def _fetch_json(self, url) -> dict:
json = r.json()
return json

def _extract_content_from_html(self, html):
extracted = trafilatura.bare_extraction(html)
def _extract_content_from_html(self, html: str, url: str):
extracted = trafilatura.bare_extraction(html, url=url, with_metadata=True)
return extracted

async def process(self, url: str) -> URLProcessingResult:
Expand All @@ -88,7 +88,7 @@ async def process(
) -> Union[URLProcessingResult, URLProcessingFailure]:
try:
html = await self._crawl_url(url)
extracted = self._extract_content_from_html(html)
extracted = self._extract_content_from_html(html, url)
return URLProcessingResult(
url=url,
title=extracted["title"],
Expand Down
9 changes: 0 additions & 9 deletions backend/run.py

This file was deleted.

0 comments on commit 9315a18

Please sign in to comment.