diff --git a/visual-retrieval-colpali/backend/colpali.py b/visual-retrieval-colpali/backend/colpali.py
index febbca2c7..b102d99f0 100644
--- a/visual-retrieval-colpali/backend/colpali.py
+++ b/visual-retrieval-colpali/backend/colpali.py
@@ -9,9 +9,13 @@
 from io import BytesIO
 from typing import Union, Tuple, List, Dict, Any

 import matplotlib
+import matplotlib.cm as cm
 import re
 import io
+import json
+import time
+
 from colpali_engine.models import ColPali, ColPaliProcessor
 from colpali_engine.utils.torch_utils import get_torch_device
 from einops import rearrange
@@ -114,6 +118,7 @@ def gen_similarity_maps(
     query_embs: torch.Tensor,
     token_idx_map: dict,
     images: List[Union[Path, str]],
+    vespa_sim_maps: List[Dict[str, Any]],
 ) -> List[Dict[str, str]]:
     """
     Generate similarity maps for the given images and query, and return base64-encoded blended images.
@@ -131,8 +136,8 @@
     Returns:
         List[Dict[str, str]]: A list where each item is a dictionary mapping tokens to base64-encoded blended images.
     """
-    import numpy as np
-    import matplotlib.cm as cm
+
+    start = time.perf_counter()

     # Prepare the colormap once to avoid recomputation
     colormap = cm.get_cmap("viridis")
@@ -158,39 +163,74 @@
         original_sizes.append(img_pil.size)  # (width, height)
         processed_images.append(img_pil)

-    # Preprocess inputs
-    input_image_processed = processor.process_images(processed_images).to(device)
-
-    # Forward passes
-    with torch.no_grad():
-        output_image = model.forward(**input_image_processed)
-
-    # Remove the special tokens from the output
-    output_image = output_image[:, : processor.image_seq_length, :]
+    # If similarity maps are provided, use them instead of computing them
+    if vespa_sim_maps:
+        print("Using provided similarity maps")
+        # A sim map looks like this:
+        # "similarities": [
+        #     {
+        #         "address": {
+        #             "patch": "0",
+        #             "querytoken": "0"
+        #         },
+        #         "value": 1.2599412202835083
+        #     },
+        #     ... and so on.
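+        # The tensor cells arrive under ["similarities"]["cells"]; each cell
+        # addresses one (querytoken, patch) pair, and the flat patch index is
+        # mapped to a 2D grid position below via integer division and modulo
+        # by n_patch_per_dim.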
+        # Now turn these into a tensor of the same shape as the computed similarity map
+        vespa_sim_map_tensor = torch.zeros(
+            (len(vespa_sim_maps), query_embs.size(dim=1), vit_config.n_patch_per_dim, vit_config.n_patch_per_dim)
+        )
+        for idx, vespa_sim_map in enumerate(vespa_sim_maps):
+            for cell in vespa_sim_map["similarities"]["cells"]:
+                patch = int(cell["address"]["patch"])
+                if patch >= processor.image_seq_length:
+                    continue
+                query_token = int(cell["address"]["querytoken"])
+                value = cell["value"]
+                vespa_sim_map_tensor[idx, query_token, patch // vit_config.n_patch_per_dim, patch % vit_config.n_patch_per_dim] = value
+
+        # Normalize the similarity map per query token
+        similarity_map_normalized = normalize_similarity_map_per_query_token(vespa_sim_map_tensor)
+    else:
+        # Preprocess inputs
+        print("Computing similarity maps")
+        start2 = time.perf_counter()
+        input_image_processed = processor.process_images(processed_images).to(device)
+
+        # Forward passes
+        with torch.no_grad():
+            output_image = model.forward(**input_image_processed)
+
+        # Remove the special tokens from the output
+        output_image = output_image[:, : processor.image_seq_length, :]
+
+        # Rearrange the output image tensor to represent the 2D grid of patches
+        output_image = rearrange(
+            output_image,
+            "b (h w) c -> b h w c",
+            h=vit_config.n_patch_per_dim,
+            w=vit_config.n_patch_per_dim,
+        )

-    # Rearrange the output image tensor to represent the 2D grid of patches
-    output_image = rearrange(
-        output_image,
-        "b (h w) c -> b h w c",
-        h=vit_config.n_patch_per_dim,
-        w=vit_config.n_patch_per_dim,
-    )
+        # Ensure query_embs has batch dimension
+        if query_embs.dim() == 2:
+            query_embs = query_embs.unsqueeze(0).to(device)
+        else:
+            query_embs = query_embs.to(device)

-    # Ensure query_embs has batch dimension
-    if query_embs.dim() == 2:
-        query_embs = query_embs.unsqueeze(0).to(device)
-    else:
-        query_embs = query_embs.to(device)
+        # Compute the similarity map
+        similarity_map = torch.einsum(
+            "bnk,bhwk->bnhw", query_embs, output_image
+        )  # Shape: (batch_size, query_tokens, h, w)

-    # Compute the similarity map
-    similarity_map = torch.einsum(
-        "bnk,bhwk->bnhw", query_embs, output_image
-    )  # Shape: (batch_size, query_tokens, h, w)
+        end2 = time.perf_counter()
+        print(f"Similarity map computation took: {end2 - start2} s")

-    # Normalize the similarity map per query token
-    similarity_map_normalized = normalize_similarity_map_per_query_token(similarity_map)
+        # Normalize the similarity map per query token
+        similarity_map_normalized = normalize_similarity_map_per_query_token(similarity_map)

     # Collect the blended images
+    start3 = time.perf_counter()
     results = []
     for idx, img in enumerate(original_images):
         original_size = original_sizes[idx]  # (width, height)
@@ -248,6 +288,9 @@
             # Store the base64-encoded image
             result_per_image[token] = blended_img_base64
         results.append(result_per_image)
+    end3 = time.perf_counter()
+    print(f"Collecting blended images took: {end3 - start3} s")
+    print(f"Total heatmap generation took: {end3 - start} s")

     return results
@@ -285,9 +328,11 @@ async def query_vespa_default(
 ) -> dict:
     async with app.asyncio(connections=1, total_timeout=120) as session:
         query_embedding = format_q_embs(q_emb)
+
+        start = time.perf_counter()
         response: VespaQueryResponse = await session.query(
             body={
-                "yql": "select id,title,url,full_image,page_number,snippet,text from pdf_page where userQuery();",
+                "yql": "select id,title,url,full_image,page_number,snippet,text,summaryfeatures from pdf_page where userQuery();",
                 "ranking": "default",
                 "query": query,
                 "timeout": timeout,
@@ -298,6 +343,9 @@
             },
         )
         assert response.is_successful(), response.json
+        stop = time.perf_counter()
+        print(f"Query time + data transfer took: {stop - start} s, Vespa said searchtime was {response.json.get('timing', {}).get('searchtime', -1)} s")
+        open("response.json", "w").write(json.dumps(response.json))
     return format_query_results(query, response)
@@ -447,10 +495,14 @@ def add_sim_maps_to_result(
 ) -> Dict[str, Any]:
     vit_config = load_vit_config(model)
     imgs: List[str] = []
+    vespa_sim_maps: List[Dict[str, Any]] = []
     for single_result in result["root"]["children"]:
         img = single_result["fields"]["full_image"]
         if img:
             imgs.append(img)
+        vespa_sim_map = single_result["fields"].get("summaryfeatures", None)
+        if vespa_sim_map:
+            vespa_sim_maps.append(vespa_sim_map)
     sim_map_imgs = gen_similarity_maps(
         model=model,
         processor=processor,
@@ -460,6 +512,7 @@
         query_embs=q_embs,
         token_idx_map=token_to_idx,
         images=imgs,
+        vespa_sim_maps=vespa_sim_maps,
     )
     for single_result, sim_map_dict in zip(result["root"]["children"], sim_map_imgs):
         for token, sim_mapb64 in sim_map_dict.items():
@@ -491,6 +544,7 @@
         query_embs=q_embs,
         token_idx_map=token_to_idx,
         images=[image_filepath],
+        vespa_sim_maps=[],
     )
     for fig_token in figs_images:
         for token, (fig, ax) in fig_token.items():
diff --git a/visual-retrieval-colpali/colpali-with-snippets/schemas/pdf_page.sd b/visual-retrieval-colpali/colpali-with-snippets/schemas/pdf_page.sd
index 66ad729da..9563de847 100644
--- a/visual-retrieval-colpali/colpali-with-snippets/schemas/pdf_page.sd
+++ b/visual-retrieval-colpali/colpali-with-snippets/schemas/pdf_page.sd
@@ -92,6 +92,13 @@
             }
         }
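+        # Per-cell ColPali scores for the frontend heatmaps: multiply each
+        # query-token embedding with each bit-unpacked patch embedding and
+        # sum over the vector dimension v, leaving one value per
+        # (querytoken, patch) pair.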
+        function similarities() {
+            expression {
+                sum(
+                    query(qt) * unpack_bits(attribute(embedding)), v
+                )
+            }
+        }
         function bm25_score() {
             expression {
                 bm25(title) + bm25(text)
             }
         }
@@ -108,6 +115,7 @@
                 max_sim
             }
         }
+        summary-features: similarities
     }
     rank-profile retrieval-and-rerank {
         inputs {
diff --git a/visual-retrieval-colpali/main.py b/visual-retrieval-colpali/main.py
index a149276a0..29b402a0e 100644
--- a/visual-retrieval-colpali/main.py
+++ b/visual-retrieval-colpali/main.py
@@ -5,6 +5,7 @@
 from fasthtml.common import *
 from shad4fast import *
 from vespa.application import Vespa
+import time

 from backend.colpali import (
     get_result_from_query,
@@ -103,6 +104,7 @@ async def get(request, query: str, nn: bool = True):
     processor = manager.processor
     q_embs, token_to_idx = get_query_embeddings_and_token_map(processor, model, query)

+    start = time.perf_counter()
     # Fetch real search results from Vespa
     result = await get_result_from_query(
         app=vespa_app,
@@ -113,13 +115,14 @@ async def get(request, query: str, nn: bool = True):
         token_to_idx=token_to_idx,
         ranking=ranking_value,
     )
+    end = time.perf_counter()
+    print(f"Search results fetched in {end - start:.2f} seconds, Vespa says searchtime was {result['timing']['searchtime']} seconds")
     # Start generating the similarity map in the background
     asyncio.create_task(
         generate_similarity_map(
             model, processor, query, q_embs, token_to_idx, result, query_id
         )
     )
-    print("Search results fetched")
     search_results = (
         result["root"]["children"]
         if "root" in result and "children" in result["root"]