Skip to content

Commit

Permalink
add pure bm25 ranking option
Browse files Browse the repository at this point in the history
  • Loading branch information
andreer committed Oct 10, 2024
1 parent ae6a096 commit 6a17020
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 18 deletions.
33 changes: 29 additions & 4 deletions visual-retrieval-colpali/backend/colpali.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ async def query_vespa_default(
query_embedding = format_q_embs(q_emb)
response: VespaQueryResponse = await session.query(
body={
"yql": "select id,title,url,full_image,page_number,text from pdf_page where userQuery();",
"yql": "select id,title,url,full_image,page_number,snippet,text from pdf_page where userQuery();",
"ranking": "default",
"query": query,
"timeout": timeout,
Expand All @@ -300,6 +300,27 @@ async def query_vespa_default(
assert response.is_successful(), response.json
return format_query_results(query, response)

async def query_vespa_bm25(
    app: Vespa,
    query: str,
    hits: int = 3,
    timeout: str = "10s",
    **kwargs,
) -> dict:
    """Run a text-only query against Vespa using the ``bm25`` rank profile.

    Unlike the embedding-based query helpers, this one sends no query
    tensors: only the raw query string is passed to Vespa.

    Args:
        app: Vespa application handle used to open the async session.
        query: User query string, sent as the ``query`` request parameter.
        hits: Number of hits to request.
        timeout: Query timeout string sent in the request body.
        **kwargs: Extra request-body parameters. They are merged last, so
            they override any of the defaults set here.

    Returns:
        The value produced by ``format_query_results(query, response)``.

    Raises:
        AssertionError: If Vespa reports the query as unsuccessful. The
            exception carries ``response.json`` as its argument.
    """
    async with app.asyncio(connections=1, total_timeout=120) as session:
        response: VespaQueryResponse = await session.query(
            body={
                "yql": "select id,title,url,full_image,page_number,snippet,text from pdf_page where userQuery();",
                "ranking": "bm25",
                "query": query,
                "timeout": timeout,
                "hits": hits,
                "presentation.timing": True,
                **kwargs,
            },
        )
        # Explicit raise instead of a bare `assert`: assert statements are
        # removed under `python -O`, which would let a failed query through.
        # AssertionError is kept so existing callers see the same exception.
        if not response.is_successful():
            raise AssertionError(response.json)
    return format_query_results(query, response)

def float_to_binary_embedding(float_query_embedding: dict) -> dict:
binary_query_embeddings = {}
Expand Down Expand Up @@ -387,17 +408,21 @@ async def get_result_from_query(
processor: ColPaliProcessor,
model: ColPali,
query: str,
nn: bool = False,
ranking: str,
gen_sim_map: bool = True,
) -> Dict[str, Any]:
# Get the query embeddings and token map
print(query)
q_embs, token_to_idx = get_query_embeddings_and_token_map(processor, model, query)
print(token_to_idx)
if nn:
if ranking == "nn+colpali":
result = await query_vespa_nearest_neighbor(app, query, q_embs)
else:
elif ranking == "bm25+colpali":
result = await query_vespa_default(app, query, q_embs)
elif ranking == "bm25":
result = await query_vespa_bm25(app, query)
else:
raise ValueError(f"Unsupported ranking: {ranking}")
# Print score, title id, and text of the results
for idx, child in enumerate(result["root"]["children"]):
print(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ schema pdf_page {
fieldset image {
fields: image
}

# Text-only rank profile: the first-phase score is the sum of the bm25
# rank features of the title and text fields. No query tensors are declared.
# NOTE(review): assumes both fields are indexed with bm25 enabled - confirm
# against the field definitions in this schema.
rank-profile bm25 {
first-phase {
expression: bm25(title) + bm25(text)
}
}

rank-profile default {
inputs {
query(qt) tensor<float>(querytoken{}, v[128])
Expand Down
9 changes: 9 additions & 0 deletions visual-retrieval-colpali/colpali-with-snippets/services.xml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@
<token id="colpali-read" />
</client>
</clients>
<!-- Overrides the bold open/close markers with an escaped HTML strong
     element and sets the separator string to "...".
     NOTE(review): presumably these markers wrap matched terms in dynamic
     snippet output; confirm against the Vespa qr-searchers config docs. -->
<config name="container.qr-searchers">
<tag>
<bold>
<open>&lt;strong&gt;</open>
<close>&lt;/strong&gt;</close>
</bold>
<separator>...</separator>
</tag>
</config>
</container>
<content id="colpali_content" version="1.0">
<redundancy>1</redundancy>
Expand Down
18 changes: 9 additions & 9 deletions visual-retrieval-colpali/frontend/app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from urllib.parse import quote_plus

from fasthtml.components import H1, H2, Div, Form, Img, P, Span
from fasthtml.components import H1, H2, Div, Form, Img, P, Span, NotStr
from fasthtml.xtend import A, Script
from lucide_fasthtml import Lucide
from shad4fast import Badge, Button, Input, Label, RadioGroup, RadioGroupItem
Expand Down Expand Up @@ -77,18 +77,18 @@ def SearchBox(with_border=False, query_value="", ranking_value="option1"):
Span("Ranking by:", cls="text-muted-foreground text-xs font-semibold"),
RadioGroup(
Div(
RadioGroupItem(value="colpali", id="colpali"),
Label("colpali", htmlFor="colpali"),
RadioGroupItem(value="nn+colpali", id="nn+colpali"),
Label("nn+colpali", htmlFor="nn+colpali"),
cls="flex items-center space-x-2",
),
Div(
RadioGroupItem(value="bm25", id="bm25"),
Label("bm25", htmlFor="bm25"),
RadioGroupItem(value="bm25+colpali", id="bm25+colpali"),
Label("bm25+colpali", htmlFor="bm25+colpali"),
cls="flex items-center space-x-2",
),
Div(
RadioGroupItem(value="option3", id="option3"),
Label("option3", htmlFor="option3"),
RadioGroupItem(value="bm25", id="bm25"),
Label("bm25", htmlFor="bm25"),
cls="flex items-center space-x-2",
),
name="ranking",
Expand Down Expand Up @@ -276,8 +276,8 @@ def SearchResult(results=[], show_sim_map=False):
H2(fields["title"], cls="text-xl font-semibold"),
P("Page " + str(fields["page_number"]), cls="text-muted-foreground"),
P("Relevance score: " + str(result["relevance"]), cls="text-muted-foreground"),
P(fields["snippet"], cls="text-muted-foreground"),
P(fields["text"], cls="text-muted-foreground"),
P(NotStr(fields["snippet"]), cls="text-muted-foreground"),
P(NotStr(fields["text"]), cls="text-muted-foreground"),
cls="text-sm grid gap-y-4",
),
cls="bg-background px-3 py-5 hidden md:block",
Expand Down
7 changes: 2 additions & 5 deletions visual-retrieval-colpali/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,11 @@ def get(request, query: str, nn: bool = True, sim_map: bool = True):
return RedirectResponse("/search")

# Extract ranking option from the request
ranking_value = request.query_params.get("ranking", "option1")
ranking_value = request.query_params.get("ranking")
print(
f"/fetch_results: Fetching results for query: {query}, ranking: {ranking_value}"
)

if "bm25" in ranking_value:
nn = False

# Fetch model and processor
manager = ModelManager.get_instance()
model = manager.model
Expand All @@ -117,7 +114,7 @@ def get(request, query: str, nn: bool = True, sim_map: bool = True):
processor=processor,
model=model,
query=query,
nn=nn,
ranking=ranking_value,
gen_sim_map=sim_map,
)
)
Expand Down

0 comments on commit 6a17020

Please sign in to comment.