compute similarity map in vespa #1490

Merged (3 commits, Oct 11, 2024)
Changes from all commits
114 changes: 84 additions & 30 deletions visual-retrieval-colpali/backend/colpali.py
@@ -9,9 +9,13 @@
from io import BytesIO
from typing import Union, Tuple, List, Dict, Any
import matplotlib
import matplotlib.cm as cm
import re
import io

import json
import time

from colpali_engine.models import ColPali, ColPaliProcessor
from colpali_engine.utils.torch_utils import get_torch_device
from einops import rearrange
Expand Down Expand Up @@ -114,6 +118,7 @@ def gen_similarity_maps(
query_embs: torch.Tensor,
token_idx_map: dict,
images: List[Union[Path, str]],
vespa_sim_maps: List[str],
) -> List[Dict[str, str]]:
"""
Generate similarity maps for the given images and query, and return base64-encoded blended images.
@@ -131,8 +136,8 @@
    Returns:
        List[Dict[str, str]]: A list where each item is a dictionary mapping tokens to base64-encoded blended images.
    """
    start = time.perf_counter()

    # Prepare the colormap once to avoid recomputation
    colormap = cm.get_cmap("viridis")
Expand All @@ -158,39 +163,74 @@ def gen_similarity_maps(
original_sizes.append(img_pil.size) # (width, height)
processed_images.append(img_pil)

    # If similarity maps are provided, use them instead of computing them
    if vespa_sim_maps:
        print("Using provided similarity maps")
        # A sim map looks like this:
        # "similarities": [
        #     {
        #         "address": {
        #             "patch": "0",
        #             "querytoken": "0"
        #         },
        #         "value": 1.2599412202835083
        #     },
        #     ... and so on.
        # Now turn these into a tensor of the same shape as the locally computed similarity map
        vespa_sim_map_tensor = torch.zeros(
            (
                len(vespa_sim_maps),
                query_embs.size(dim=1),
                vit_config.n_patch_per_dim,
                vit_config.n_patch_per_dim,
            )
        )
        for idx, vespa_sim_map in enumerate(vespa_sim_maps):
            for cell in vespa_sim_map["similarities"]["cells"]:
                patch = int(cell["address"]["patch"])
                if patch >= processor.image_seq_length:
                    continue
                query_token = int(cell["address"]["querytoken"])
                value = cell["value"]
                vespa_sim_map_tensor[
                    idx,
                    query_token,
                    patch // vit_config.n_patch_per_dim,
                    patch % vit_config.n_patch_per_dim,
                ] = value

        # Normalize the similarity map per query token
        similarity_map_normalized = normalize_similarity_map_per_query_token(
            vespa_sim_map_tensor
        )
    else:
        # Preprocess inputs
        print("Computing similarity maps")
        start2 = time.perf_counter()
        input_image_processed = processor.process_images(processed_images).to(device)

        # Forward passes
        with torch.no_grad():
            output_image = model.forward(**input_image_processed)

        # Remove the special tokens from the output
        output_image = output_image[:, : processor.image_seq_length, :]

        # Rearrange the output image tensor to represent the 2D grid of patches
        output_image = rearrange(
            output_image,
            "b (h w) c -> b h w c",
            h=vit_config.n_patch_per_dim,
            w=vit_config.n_patch_per_dim,
        )

        # Ensure query_embs has batch dimension
        if query_embs.dim() == 2:
            query_embs = query_embs.unsqueeze(0).to(device)
        else:
            query_embs = query_embs.to(device)

        # Compute the similarity map
        similarity_map = torch.einsum(
            "bnk,bhwk->bnhw", query_embs, output_image
        )  # Shape: (batch_size, query_tokens, h, w)
        end2 = time.perf_counter()
        print(f"Similarity map computation took: {end2 - start2} s")

        # Normalize the similarity map per query token
        similarity_map_normalized = normalize_similarity_map_per_query_token(
            similarity_map
        )

    # Collect the blended images
    start3 = time.perf_counter()
    results = []
    for idx, img in enumerate(original_images):
        original_size = original_sizes[idx]  # (width, height)
@@ -248,6 +288,9 @@ def gen_similarity_maps(
            # Store the base64-encoded image
            result_per_image[token] = blended_img_base64
        results.append(result_per_image)
    end3 = time.perf_counter()
    print(f"Collecting blended images took: {end3 - start3} s")
    print(f"Total heatmap generation took: {end3 - start} s")
    return results
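For orientation, here is a hedged sketch of the per-hit payload shape the parsing loop above expects: the `summaryfeatures` tensor is rendered as a `cells` list, and the flat `patch` index is mapped row-major onto the ViT patch grid. Values are illustrative and `n_patch_per_dim = 32` is a made-up value, not taken from this PR:

```python
# Illustrative only: shape of one entry of vespa_sim_maps, per the parsing code above.
vespa_sim_map = {
    "similarities": {
        "cells": [
            {"address": {"patch": "0", "querytoken": "0"}, "value": 1.2599412202835083},
            {"address": {"patch": "1", "querytoken": "0"}, "value": 0.8731},
        ]
    }
}

# Row-major mapping from flat patch index to (row, col) on the patch grid,
# as used when filling vespa_sim_map_tensor.
n_patch_per_dim = 32  # hypothetical grid size
patch = 33
row, col = patch // n_patch_per_dim, patch % n_patch_per_dim
assert (row, col) == (1, 1)
```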


@@ -285,9 +328,11 @@ async def query_vespa_default(
) -> dict:
    async with app.asyncio(connections=1, total_timeout=120) as session:
        query_embedding = format_q_embs(q_emb)

        start = time.perf_counter()
        response: VespaQueryResponse = await session.query(
            body={
                "yql": "select id,title,url,full_image,page_number,snippet,text,summaryfeatures from pdf_page where userQuery();",
                "ranking": "default",
                "query": query,
                "timeout": timeout,
@@ -298,6 +343,9 @@
            },
        )
        assert response.is_successful(), response.json
        stop = time.perf_counter()
        print(
            f"Query time + data transfer took: {stop - start} s, "
            f"Vespa reported searchtime was {response.json.get('timing', {}).get('searchtime', -1)} s"
        )
        open("response.json", "w").write(json.dumps(response.json))
        return format_query_results(query, response)
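Adding `summaryfeatures` to the yql select list is what carries the `similarities` cells (exposed by `summary-features: similarities` in the schema change below) back with each hit. A hedged sketch of the access path, matching the accessors used later in this diff; the literal values are placeholders:

```python
# Illustrative only: where the Vespa similarity cells sit on each hit.
hit = {
    "fields": {
        "full_image": "<base64...>",
        "summaryfeatures": {"similarities": {"cells": []}},
    }
}
vespa_sim_map = hit["fields"].get("summaryfeatures", None)
```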


@@ -447,10 +495,14 @@ def add_sim_maps_to_result(
) -> Dict[str, Any]:
    vit_config = load_vit_config(model)
    imgs: List[str] = []
    vespa_sim_maps: List[str] = []
    for single_result in result["root"]["children"]:
        img = single_result["fields"]["full_image"]
        if img:
            imgs.append(img)
        vespa_sim_map = single_result["fields"].get("summaryfeatures", None)
        if vespa_sim_map:
            vespa_sim_maps.append(vespa_sim_map)
    sim_map_imgs = gen_similarity_maps(
        model=model,
        processor=processor,
@@ -460,6 +512,7 @@
        query_embs=q_embs,
        token_idx_map=token_to_idx,
        images=imgs,
        vespa_sim_maps=vespa_sim_maps,
    )
    for single_result, sim_map_dict in zip(result["root"]["children"], sim_map_imgs):
        for token, sim_mapb64 in sim_map_dict.items():
@@ -491,6 +544,7 @@ def add_sim_maps_to_result(
        query_embs=q_embs,
        token_idx_map=token_to_idx,
        images=[image_filepath],
        vespa_sim_maps=None,
    )
    for fig_token in figs_images:
        for token, (fig, ax) in fig_token.items():
@@ -92,6 +92,13 @@ schema pdf_page {

            }
        }
        function similarities() {
            expression {
                sum(
                    query(qt) * unpack_bits(attribute(embedding)), v
                )
            }
        }
        function bm25_score() {
            expression {
                bm25(title) + bm25(text)
@@ -108,6 +115,7 @@
                max_sim
            }
        }
        summary-features: similarities
    }
    rank-profile retrieval-and-rerank {
        inputs {
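The `similarities()` rank function does inside Vespa what the local `torch.einsum("bnk,bhwk->bnhw", ...)` path computes: for each query token and image patch, a dot product over the embedding dimension `v`, with `summary-features: similarities` returning those cells alongside each hit. A minimal sketch of that correspondence, using made-up shapes (5 query tokens, 1024 patches, 128-dim embeddings) and random floats in place of the unpacked binary patch embeddings:

```python
import torch

q = torch.randn(5, 128)     # query(qt): one embedding per query token
e = torch.randn(1024, 128)  # unpack_bits(attribute(embedding)): one embedding per patch

# sum(query(qt) * unpack_bits(attribute(embedding)), v): multiply elementwise
# over the embedding axis v and sum it away, one score per (querytoken, patch).
scores = torch.einsum("nv,pv->np", q, e)
assert scores.shape == (5, 1024)
```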
5 changes: 4 additions & 1 deletion visual-retrieval-colpali/main.py
@@ -5,6 +5,7 @@
from fasthtml.common import *
from shad4fast import *
from vespa.application import Vespa
import time

from backend.colpali import (
    get_result_from_query,
@@ -103,6 +104,7 @@ async def get(request, query: str, nn: bool = True):
    processor = manager.processor
    q_embs, token_to_idx = get_query_embeddings_and_token_map(processor, model, query)

    start = time.perf_counter()
    # Fetch real search results from Vespa
    result = await get_result_from_query(
        app=vespa_app,
@@ -113,13 +115,14 @@
        token_to_idx=token_to_idx,
        ranking=ranking_value,
    )
    end = time.perf_counter()
    print(
        f"Search results fetched in {end - start:.2f} seconds, "
        f"Vespa says searchtime was {result['timing']['searchtime']} seconds"
    )
    # Start generating the similarity map in the background
    asyncio.create_task(
        generate_similarity_map(
            model, processor, query, q_embs, token_to_idx, result, query_id
        )
    )
    search_results = (
        result["root"]["children"]
        if "root" in result and "children" in result["root"]