diff --git a/visual-retrieval-colpali/.gitignore b/visual-retrieval-colpali/.gitignore index abe93bb20..cfdf56d51 100644 --- a/visual-retrieval-colpali/.gitignore +++ b/visual-retrieval-colpali/.gitignore @@ -9,7 +9,7 @@ template/ output/ pdfs/ colpalidemo/ -static/full_images/ -static/sim_maps/ +src/static/full_images/ +src/static/sim_maps/ embeddings/ hf_dataset/ \ No newline at end of file diff --git a/visual-retrieval-colpali/README.md b/visual-retrieval-colpali/README.md index 4f23c4fcb..eda14ba8c 100644 --- a/visual-retrieval-colpali/README.md +++ b/visual-retrieval-colpali/README.md @@ -1,22 +1,3 @@ ---- -title: ColPali 🤝 Vespa - Visual Retrieval -short_description: Visual Retrieval with ColPali and Vespa -emoji: 👀 -colorFrom: purple -colorTo: blue -sdk: gradio -sdk_version: 4.44.0 -app_file: main.py -pinned: false -license: apache-2.0 -models: - - vidore/colpaligemma-3b-pt-448-base - - vidore/colpali-v1.2 -preload_from_hub: - - vidore/colpaligemma-3b-pt-448-base config.json,model-00001-of-00002.safetensors,model-00002-of-00002.safetensors,model.safetensors.index.json,preprocessor_config.json,special_tokens_map.json,tokenizer.json,tokenizer_config.json 12c59eb7e23bc4c26876f7be7c17760d5d3a1ffa - - vidore/colpali-v1.2 adapter_config.json,adapter_model.safetensors,preprocessor_config.json,special_tokens_map.json,tokenizer.json,tokenizer_config.json 9912ce6f8a462d8cf2269f5606eabbd2784e764f ---- - @@ -75,7 +56,7 @@ Skip to [Installing dependencies using `uv`](#installing-dependencies-using-uv) You can install the dependencies with `pip`: ```bash -pip install -r requirements.txt +pip install -r src/requirements.txt ``` ### Installing dependencies using `uv` @@ -107,7 +88,7 @@ uv sync --extra dev ## Running the application locally -To run the application locally, you can run: +To run the application locally, you can change into the `src` directory and run: ```bash python main.py @@ -122,27 +103,39 @@ This will start a local server, and you can access the application at `http://lo Before a deploy, make sure to run this to compile the `uv` lock file to `requirements.txt` if you have made changes to the dependencies: ```bash -uv pip compile pyproject.toml -o requirements.txt +uv pip compile pyproject.toml -o src/requirements.txt ``` This will make sure that the dependencies in your `pyproject.toml` are compiled to the `requirements.txt` file, which is used by the huggingface space. ### Deploying to huggingface +Note that you need to set `HF_TOKEN` environment variable first. +This is personal, and must be created at [huggingface](https://huggingface.co/settings/tokens). +Make sure the token has `write` access. +Be aware that this will not delete existing files, only modify or add, +see [huggingface-cli](https://huggingface.co/docs/huggingface_hub/en/guides/upload#upload-from-the-cli) for more +information. + +#### Update your space configuration + +The `src/README.md` file contains the configuration for the space. +Feel free to update this file to match your own configuration - name, description, etc. + +Note that we can actually use the `gradio` SDK of spaces, to serve FastHTML apps as well, as long as we serve the app on port `7860`. +See [Custom python spaces](https://huggingface.co/docs/hub/en/spaces-sdks-python) for more information. + +#### Upload the files + To deploy, run (Replace `vespa-engine/colpali-vespa-visual-retrieval` with your own huggingface user/repo name, does not need to exist beforehand) ```bash -huggingface-cli upload vespa-engine/colpali-vespa-visual-retrieval . . 
--repo-type=space +huggingface-cli upload vespa-engine/colpali-vespa-visual-retrieval src . --repo-type=space ``` -Note that you need to set `HF_TOKEN` environment variable first. -This is personal, and must be created at [huggingface](https://huggingface.co/settings/tokens). -Make sure the token has `write` access. -Be aware that this will not delete existing files, only modify or add, -see [huggingface-cli](https://huggingface.co/docs/huggingface_hub/en/guides/upload#upload-from-the-cli) for more -information. +Note that we upload only the `src` directory. ## Development diff --git a/visual-retrieval-colpali/backend/colpali.py b/visual-retrieval-colpali/backend/colpali.py deleted file mode 100644 index ee83aeaab..000000000 --- a/visual-retrieval-colpali/backend/colpali.py +++ /dev/null @@ -1,325 +0,0 @@ -#!/usr/bin/env python3 - -import torch -from PIL import Image -import numpy as np -from typing import cast, Generator -from pathlib import Path -import base64 -from io import BytesIO -from typing import Union, Tuple, List -import matplotlib -import matplotlib.cm as cm -import re -import io - -import time -import backend.testquery as testquery - -from colpali_engine.models import ColPali, ColPaliProcessor -from colpali_engine.utils.torch_utils import get_torch_device -from einops import rearrange -from vidore_benchmark.interpretability.torch_utils import ( - normalize_similarity_map_per_query_token, -) -from vidore_benchmark.interpretability.vit_configs import VIT_CONFIG - -matplotlib.use("Agg") -# Prepare the colormap once to avoid recomputation -colormap = cm.get_cmap("viridis") - -COLPALI_GEMMA_MODEL_NAME = "vidore/colpaligemma-3b-pt-448-base" - - -def load_model() -> Tuple[ColPali, ColPaliProcessor]: - model_name = "vidore/colpali-v1.2" - - device = get_torch_device("auto") - print(f"Using device: {device}") - - # Load the model - model = cast( - ColPali, - ColPali.from_pretrained( - model_name, - torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32, - device_map=device, - ), - ).eval() - - # Load the processor - processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name)) - return model, processor, device - - -def load_vit_config(model): - # Load the ViT config - print(f"VIT config: {VIT_CONFIG}") - vit_config = VIT_CONFIG[COLPALI_GEMMA_MODEL_NAME] - return vit_config - - -def gen_similarity_maps( - model: ColPali, - processor: ColPaliProcessor, - device, - query: str, - query_embs: torch.Tensor, - token_idx_map: dict, - images: List[Union[Path, str]], - vespa_sim_maps: List[str], -) -> Generator[Tuple[int, str, str], None, None]: - """ - Generate similarity maps for the given images and query, and return base64-encoded blended images. - - Args: - model (ColPali): The model used for generating embeddings. - processor (ColPaliProcessor): Processor for images and text. - device: Device to run the computations on. - vit_config: Configuration for the Vision Transformer. - query (str): The query string. - query_embs (torch.Tensor): Query embeddings. - token_idx_map (dict): Mapping from indices to tokens. - images (List[Union[Path, str]]): List of image paths or base64-encoded strings. - vespa_sim_maps (List[str]): List of Vespa similarity maps. - - Yields: - Tuple[int, str, str]: A tuple containing the image index, the selected token, and the base64-encoded image. 
- - """ - vit_config = load_vit_config(model) - # Process images and store original images and sizes - processed_images = [] - original_images = [] - original_sizes = [] - for img in images: - if isinstance(img, Path): - try: - img_pil = Image.open(img).convert("RGB") - except Exception as e: - raise ValueError(f"Failed to open image from path: {e}") - elif isinstance(img, str): - try: - img_pil = Image.open(BytesIO(base64.b64decode(img))).convert("RGB") - except Exception as e: - raise ValueError(f"Failed to open image from base64 string: {e}") - else: - raise ValueError(f"Unsupported image type: {type(img)}") - original_images.append(img_pil.copy()) - original_sizes.append(img_pil.size) # (width, height) - processed_images.append(img_pil) - - # If similarity maps are provided, use them instead of computing them - if vespa_sim_maps: - print("Using provided similarity maps") - # A sim map looks like this: - # "quantized": [ - # { - # "address": { - # "patch": "0", - # "querytoken": "0" - # }, - # "value": 12, # score in range [-128, 127] - # }, - # ... and so on. - # Now turn these into a tensor of same shape as previous similarity map - vespa_sim_map_tensor = torch.zeros( - ( - len(vespa_sim_maps), - query_embs.size(dim=1), - vit_config.n_patch_per_dim, - vit_config.n_patch_per_dim, - ) - ) - for idx, vespa_sim_map in enumerate(vespa_sim_maps): - for cell in vespa_sim_map["quantized"]["cells"]: - patch = int(cell["address"]["patch"]) - # if dummy model then just use 1024 as the image_seq_length - - if hasattr(processor, "image_seq_length"): - image_seq_length = processor.image_seq_length - else: - image_seq_length = 1024 - - if patch >= image_seq_length: - continue - query_token = int(cell["address"]["querytoken"]) - value = cell["value"] - vespa_sim_map_tensor[ - idx, - int(query_token), - int(patch) // vit_config.n_patch_per_dim, - int(patch) % vit_config.n_patch_per_dim, - ] = value - - # Normalize the similarity map per query token - similarity_map_normalized = normalize_similarity_map_per_query_token( - vespa_sim_map_tensor - ) - else: - # Preprocess inputs - print("Computing similarity maps") - start2 = time.perf_counter() - input_image_processed = processor.process_images(processed_images).to(device) - - # Forward passes - with torch.no_grad(): - output_image = model.forward(**input_image_processed) - - # Remove the special tokens from the output - output_image = output_image[:, : processor.image_seq_length, :] - - # Rearrange the output image tensor to represent the 2D grid of patches - output_image = rearrange( - output_image, - "b (h w) c -> b h w c", - h=vit_config.n_patch_per_dim, - w=vit_config.n_patch_per_dim, - ) - - # Ensure query_embs has batch dimension - if query_embs.dim() == 2: - query_embs = query_embs.unsqueeze(0).to(device) - else: - query_embs = query_embs.to(device) - - # Compute the similarity map - similarity_map = torch.einsum( - "bnk,bhwk->bnhw", query_embs, output_image - ) # Shape: (batch_size, query_tokens, h, w) - - end2 = time.perf_counter() - print(f"Similarity map computation took: {end2 - start2} s") - - # Normalize the similarity map per query token - similarity_map_normalized = normalize_similarity_map_per_query_token( - similarity_map - ) - - # Collect the blended images - start3 = time.perf_counter() - for idx, img in enumerate(original_images): - SCALING_FACTOR = 8 - sim_map_resolution = ( - max(32, int(original_sizes[idx][0] / SCALING_FACTOR)), - max(32, int(original_sizes[idx][1] / SCALING_FACTOR)), - ) - - result_per_image = {} - for 
token_idx, token in token_idx_map.items(): - if should_filter_token(token): - continue - - # Get the similarity map for this image and the selected token - sim_map = similarity_map_normalized[idx, token_idx, :, :] # Shape: (h, w) - - # Move the similarity map to CPU, convert to float (as BFloat16 not supported by Numpy) and convert to NumPy array - sim_map_np = sim_map.cpu().float().numpy() - - # Resize the similarity map to the original image size - sim_map_img = Image.fromarray(sim_map_np) - sim_map_resized = sim_map_img.resize( - sim_map_resolution, resample=Image.BICUBIC - ) - - # Convert the resized similarity map to a NumPy array - sim_map_resized_np = np.array(sim_map_resized, dtype=np.float32) - - # Normalize the similarity map to range [0, 1] - sim_map_min = sim_map_resized_np.min() - sim_map_max = sim_map_resized_np.max() - if sim_map_max - sim_map_min > 1e-6: - sim_map_normalized = (sim_map_resized_np - sim_map_min) / ( - sim_map_max - sim_map_min - ) - else: - sim_map_normalized = np.zeros_like(sim_map_resized_np) - - # Apply a colormap to the normalized similarity map - heatmap = colormap(sim_map_normalized) # Returns an RGBA array - - # Convert the heatmap to a PIL Image - heatmap_uint8 = (heatmap * 255).astype(np.uint8) - heatmap_img = Image.fromarray(heatmap_uint8) - heatmap_img_rgba = heatmap_img.convert("RGBA") - - # Save the image to a BytesIO buffer - buffer = io.BytesIO() - heatmap_img_rgba.save(buffer, format="PNG") - buffer.seek(0) - - # Encode the image to base64 - blended_img_base64 = base64.b64encode(buffer.read()).decode("utf-8") - - # Store the base64-encoded image - result_per_image[token] = blended_img_base64 - yield idx, token, token_idx, blended_img_base64 - end3 = time.perf_counter() - print(f"Blending images took: {end3 - start3} s") - - -def get_query_embeddings_and_token_map( - processor, model, query -) -> Tuple[torch.Tensor, dict]: - if model is None: # use static test query data (saves time when testing) - return testquery.q_embs, testquery.idx_to_token - - start_time = time.perf_counter() - inputs = processor.process_queries([query]).to(model.device) - with torch.no_grad(): - embeddings_query = model(**inputs) - q_emb = embeddings_query.to("cpu")[0] # Extract the single embedding - # Use this cell output to choose a token using its index - query_tokens = processor.tokenizer.tokenize(processor.decode(inputs.input_ids[0])) - # reverse key, values in dictionary - print(query_tokens) - idx_to_token = {idx: val for idx, val in enumerate(query_tokens)} - end_time = time.perf_counter() - print(f"Query inference took: {end_time - start_time} s") - return q_emb, idx_to_token - - -def should_filter_token(token: str) -> bool: - """ - Determines whether a token should be filtered out based on predefined patterns. - - The function filters out tokens that: - - Start with '<' (e.g., '') - - Consist entirely of whitespace - - Are purely punctuation (excluding tokens that contain digits or start with '▁') - - Start with an underscore '_' - - Exactly match the word 'Question' - - Are exactly the single character '▁' - - Output of test: - - Token: '2' | False - Token: '0' | False - Token: '2' | False - Token: '3' | False - Token: '▁2' | False - Token: '▁hi' | False - Token: 'norwegian' | False - Token: 'unlisted' | False - Token: '' | True - Token: 'Question' | True - Token: ':' | True - Token: '' | True - Token: '\n' | True - Token: '▁' | True - Token: '?' 
| True - Token: ')' | True - Token: '%' | True - Token: '/)' | True - - Tokens that do not match these patterns (e.g., 'norwegian', 'unlisted') are not filtered out. - - Args: - token (str): The token to evaluate. - - Returns: - bool: True if the token should be filtered out, False otherwise. - """ - pattern = re.compile(r"^<.*$|^\s+$|^(?!.*\d)(?!▁)[^\w\s]+$|^_.*$|^Question$|^▁$") - - return bool(pattern.match(token)) diff --git a/visual-retrieval-colpali/pyproject.toml b/visual-retrieval-colpali/pyproject.toml index c9ba1d0be..c7237121a 100644 --- a/visual-retrieval-colpali/pyproject.toml +++ b/visual-retrieval-colpali/pyproject.toml @@ -37,4 +37,82 @@ feed = [ "beautifulsoup4", "pdf2image", "google-generativeai" -] \ No newline at end of file +] +[tool.ruff] +# Exclude a variety of commonly ignored directories. +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".ipynb_checkpoints", + ".mypy_cache", + ".nox", + ".pants.d", + ".pyenv", + ".pytest_cache", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + ".vscode", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "site-packages", + "venv", +] + +# Same as Black. +line-length = 88 +indent-width = 4 + +# Assume Python 3.8 +target-version = "py38" + +[tool.ruff.lint] +# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. +# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or +# McCabe complexity (`C901`) by default. +select = ["E4", "E7", "E9", "F"] +ignore = [] + +# Allow fix for all enabled rules (when `--fix`) is provided. +fixable = ["ALL"] +unfixable = [] + +# Allow unused variables when underscore-prefixed. +dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" + +[tool.ruff.format] +# Like Black, use double quotes for strings. +quote-style = "double" + +# Like Black, indent with spaces, rather than tabs. +indent-style = "space" + +# Like Black, respect magic trailing commas. +skip-magic-trailing-comma = false + +# Like Black, automatically detect the appropriate line ending. +line-ending = "auto" + +# Enable auto-formatting of code examples in docstrings. Markdown, +# reStructuredText code/literal blocks and doctests are all supported. +# +# This is currently disabled by default, but it is planned for this +# to be opt-out in the future. +docstring-code-format = false + +# Set the line length limit used when formatting code snippets in +# docstrings. +# +# This only has an effect when the `docstring-code-format` setting is +# enabled. +docstring-code-line-length = "dynamic" \ No newline at end of file diff --git a/visual-retrieval-colpali/ruff.toml b/visual-retrieval-colpali/ruff.toml deleted file mode 100644 index d28e492a4..000000000 --- a/visual-retrieval-colpali/ruff.toml +++ /dev/null @@ -1,77 +0,0 @@ -# Exclude a variety of commonly ignored directories. -exclude = [ - ".bzr", - ".direnv", - ".eggs", - ".git", - ".git-rewrite", - ".hg", - ".ipynb_checkpoints", - ".mypy_cache", - ".nox", - ".pants.d", - ".pyenv", - ".pytest_cache", - ".pytype", - ".ruff_cache", - ".svn", - ".tox", - ".venv", - ".vscode", - "__pypackages__", - "_build", - "buck-out", - "build", - "dist", - "node_modules", - "site-packages", - "venv", -] - -# Same as Black. -line-length = 88 -indent-width = 4 - -# Assume Python 3.8 -target-version = "py38" - -[lint] -# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. 
-# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or -# McCabe complexity (`C901`) by default. -select = ["E4", "E7", "E9", "F"] -ignore = [] - -# Allow fix for all enabled rules (when `--fix`) is provided. -fixable = ["ALL"] -unfixable = [] - -# Allow unused variables when underscore-prefixed. -dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" - -[format] -# Like Black, use double quotes for strings. -quote-style = "double" - -# Like Black, indent with spaces, rather than tabs. -indent-style = "space" - -# Like Black, respect magic trailing commas. -skip-magic-trailing-comma = false - -# Like Black, automatically detect the appropriate line ending. -line-ending = "auto" - -# Enable auto-formatting of code examples in docstrings. Markdown, -# reStructuredText code/literal blocks and doctests are all supported. -# -# This is currently disabled by default, but it is planned for this -# to be opt-out in the future. -docstring-code-format = false - -# Set the line length limit used when formatting code snippets in -# docstrings. -# -# This only has an effect when the `docstring-code-format` setting is -# enabled. -docstring-code-line-length = "dynamic" diff --git a/visual-retrieval-colpali/src/README.md b/visual-retrieval-colpali/src/README.md new file mode 100644 index 000000000..e07782c05 --- /dev/null +++ b/visual-retrieval-colpali/src/README.md @@ -0,0 +1,19 @@ +--- +title: ColPali 🤝 Vespa - Visual Retrieval +short_description: Visual Retrieval with ColPali and Vespa +emoji: 👀 +colorFrom: purple +colorTo: blue +sdk: gradio +sdk_version: 4.44.0 +app_file: main.py +pinned: false +license: apache-2.0 +suggested_hardware: t4-small +models: + - vidore/colpaligemma-3b-pt-448-base + - vidore/colpali-v1.2 +preload_from_hub: + - vidore/colpaligemma-3b-pt-448-base config.json,model-00001-of-00002.safetensors,model-00002-of-00002.safetensors,model.safetensors.index.json,preprocessor_config.json,special_tokens_map.json,tokenizer.json,tokenizer_config.json 12c59eb7e23bc4c26876f7be7c17760d5d3a1ffa + - vidore/colpali-v1.2 adapter_config.json,adapter_model.safetensors,preprocessor_config.json,special_tokens_map.json,tokenizer.json,tokenizer_config.json 9912ce6f8a462d8cf2269f5606eabbd2784e764f +--- \ No newline at end of file diff --git a/visual-retrieval-colpali/backend/__init__.py b/visual-retrieval-colpali/src/backend/__init__.py similarity index 100% rename from visual-retrieval-colpali/backend/__init__.py rename to visual-retrieval-colpali/src/backend/__init__.py diff --git a/visual-retrieval-colpali/backend/cache.py b/visual-retrieval-colpali/src/backend/cache.py similarity index 100% rename from visual-retrieval-colpali/backend/cache.py rename to visual-retrieval-colpali/src/backend/cache.py diff --git a/visual-retrieval-colpali/src/backend/colpali.py b/visual-retrieval-colpali/src/backend/colpali.py new file mode 100644 index 000000000..d7a2503d4 --- /dev/null +++ b/visual-retrieval-colpali/src/backend/colpali.py @@ -0,0 +1,274 @@ +import torch +from PIL import Image +import numpy as np +from typing import Generator, Tuple, List, Union, Dict +from pathlib import Path +import base64 +from io import BytesIO +import re +import io +import matplotlib.cm as cm + +from colpali_engine.models import ColPali, ColPaliProcessor +from colpali_engine.utils.torch_utils import get_torch_device +from vidore_benchmark.interpretability.torch_utils import ( + normalize_similarity_map_per_query_token, +) + + +class SimMapGenerator: + """ + Generates similarity maps based on query embeddings 
and image patches using the ColPali model. + """ + + COLPALI_GEMMA_MODEL_NAME = "vidore/colpaligemma-3b-pt-448-base" + colormap = cm.get_cmap("viridis") # Preload colormap for efficiency + + def __init__(self, model_name: str = "vidore/colpali-v1.2", n_patch: int = 32): + """ + Initializes the SimMapGenerator class with a specified model and patch dimension. + + Args: + model_name (str): The model name for loading the ColPali model. + n_patch (int): The number of patches per dimension. + """ + self.model_name = model_name + self.n_patch = n_patch + self.device = get_torch_device("auto") + print(f"Using device: {self.device}") + self.model, self.processor = self.load_model() + + def load_model(self) -> Tuple[ColPali, ColPaliProcessor]: + """ + Loads the ColPali model and processor. + + Returns: + Tuple[ColPali, ColPaliProcessor]: Loaded model and processor. + """ + model = ColPali.from_pretrained( + self.model_name, + torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32, + device_map=self.device, + ).eval() + + processor = ColPaliProcessor.from_pretrained(self.model_name) + return model, processor + + def gen_similarity_maps( + self, + query: str, + query_embs: torch.Tensor, + token_idx_map: Dict[int, str], + images: List[Union[Path, str]], + vespa_sim_maps: List[Dict], + ) -> Generator[Tuple[int, str, str], None, None]: + """ + Generates similarity maps for the provided images and query, and returns base64-encoded blended images. + + Args: + query (str): The query string. + query_embs (torch.Tensor): Query embeddings tensor. + token_idx_map (dict): Mapping from indices to tokens. + images (List[Union[Path, str]]): List of image paths or base64-encoded strings. + vespa_sim_maps (List[Dict]): List of Vespa similarity maps. + + Yields: + Tuple[int, str, str]: A tuple containing the image index, selected token, and base64-encoded image. + """ + processed_images, original_images, original_sizes = [], [], [] + for img in images: + img_pil = self._load_image(img) + original_images.append(img_pil.copy()) + original_sizes.append(img_pil.size) + processed_images.append(img_pil) + + vespa_sim_map_tensor = self._prepare_similarity_map_tensor( + query_embs, vespa_sim_maps + ) + similarity_map_normalized = normalize_similarity_map_per_query_token( + vespa_sim_map_tensor + ) + + for idx, img in enumerate(original_images): + for token_idx, token in token_idx_map.items(): + if self.should_filter_token(token): + continue + + sim_map = similarity_map_normalized[idx, token_idx, :, :] + blended_img_base64 = self._blend_image( + img, sim_map, original_sizes[idx] + ) + yield idx, token, token_idx, blended_img_base64 + + def _load_image(self, img: Union[Path, str]) -> Image: + """ + Loads an image from a file path or a base64-encoded string. + + Args: + img (Union[Path, str]): The image to load. + + Returns: + Image: The loaded PIL image. + """ + try: + if isinstance(img, Path): + return Image.open(img).convert("RGB") + elif isinstance(img, str): + return Image.open(BytesIO(base64.b64decode(img))).convert("RGB") + except Exception as e: + raise ValueError(f"Failed to load image: {e}") + + def _prepare_similarity_map_tensor( + self, query_embs: torch.Tensor, vespa_sim_maps: List[Dict] + ) -> torch.Tensor: + """ + Prepares a similarity map tensor from Vespa similarity maps. + + Args: + query_embs (torch.Tensor): Query embeddings tensor. + vespa_sim_maps (List[Dict]): List of Vespa similarity maps. + + Returns: + torch.Tensor: The prepared similarity map tensor. 
+ """ + vespa_sim_map_tensor = torch.zeros( + (len(vespa_sim_maps), query_embs.size(1), self.n_patch, self.n_patch) + ) + for idx, vespa_sim_map in enumerate(vespa_sim_maps): + for cell in vespa_sim_map["quantized"]["cells"]: + patch = int(cell["address"]["patch"]) + query_token = int(cell["address"]["querytoken"]) + value = cell["value"] + if hasattr(self.processor, "image_seq_length"): + image_seq_length = self.processor.image_seq_length + else: + image_seq_length = 1024 + + if patch >= image_seq_length: + continue + vespa_sim_map_tensor[ + idx, + query_token, + patch // self.n_patch, + patch % self.n_patch, + ] = value + return vespa_sim_map_tensor + + def _blend_image( + self, img: Image, sim_map: torch.Tensor, original_size: Tuple[int, int] + ) -> str: + """ + Blends an image with a similarity map and encodes it to base64. + + Args: + img (Image): The original image. + sim_map (torch.Tensor): The similarity map tensor. + original_size (Tuple[int, int]): The original size of the image. + + Returns: + str: The base64-encoded blended image. + """ + SCALING_FACTOR = 8 + sim_map_resolution = ( + max(32, int(original_size[0] / SCALING_FACTOR)), + max(32, int(original_size[1] / SCALING_FACTOR)), + ) + + sim_map_np = sim_map.cpu().float().numpy() + sim_map_img = Image.fromarray(sim_map_np).resize( + sim_map_resolution, resample=Image.BICUBIC + ) + sim_map_resized_np = np.array(sim_map_img, dtype=np.float32) + sim_map_normalized = self._normalize_sim_map(sim_map_resized_np) + + heatmap = self.colormap(sim_map_normalized) + heatmap_img = Image.fromarray((heatmap * 255).astype(np.uint8)).convert("RGBA") + + buffer = io.BytesIO() + heatmap_img.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + @staticmethod + def _normalize_sim_map(sim_map: np.ndarray) -> np.ndarray: + """ + Normalizes a similarity map to range [0, 1]. + + Args: + sim_map (np.ndarray): The similarity map. + + Returns: + np.ndarray: The normalized similarity map. + """ + sim_map_min, sim_map_max = sim_map.min(), sim_map.max() + if sim_map_max - sim_map_min > 1e-6: + return (sim_map - sim_map_min) / (sim_map_max - sim_map_min) + return np.zeros_like(sim_map) + + @staticmethod + def should_filter_token(token: str) -> bool: + """ + Determines if a token should be filtered out based on predefined patterns. + + The function filters out tokens that: + + - Start with '<' (e.g., '') + - Consist entirely of whitespace + - Are purely punctuation (excluding tokens that contain digits or start with '▁') + - Start with an underscore '_' + - Exactly match the word 'Question' + - Are exactly the single character '▁' + + Output of test: + Token: '2' | False + Token: '0' | False + Token: '2' | False + Token: '3' | False + Token: '▁2' | False + Token: '▁hi' | False + Token: 'norwegian' | False + Token: 'unlisted' | False + Token: '' | True + Token: 'Question' | True + Token: ':' | True + Token: '' | True + Token: '\n' | True + Token: '▁' | True + Token: '?' | True + Token: ')' | True + Token: '%' | True + Token: '/)' | True + + + Args: + token (str): The token to check. + + Returns: + bool: True if the token should be filtered out, False otherwise. + """ + pattern = re.compile( + r"^<.*$|^\s+$|^(?!.*\d)(?!▁)[^\w\s]+$|^_.*$|^Question$|^▁$" + ) + return bool(pattern.match(token)) + + # TODO: Would be nice to @lru_cache this method. + def get_query_embeddings_and_token_map( + self, query: str + ) -> Tuple[torch.Tensor, dict]: + """ + Retrieves query embeddings and a token index map. 
+ + Args: + query (str): The query string. + + Returns: + Tuple[torch.Tensor, dict]: Query embeddings and token index map. + """ + inputs = self.processor.process_queries([query]).to(self.model.device) + with torch.no_grad(): + q_emb = self.model(**inputs).to("cpu")[0] + + query_tokens = self.processor.tokenizer.tokenize( + self.processor.decode(inputs.input_ids[0]) + ) + idx_to_token = {idx: token for idx, token in enumerate(query_tokens)} + return q_emb, idx_to_token diff --git a/visual-retrieval-colpali/backend/modelmanager.py b/visual-retrieval-colpali/src/backend/modelmanager.py similarity index 100% rename from visual-retrieval-colpali/backend/modelmanager.py rename to visual-retrieval-colpali/src/backend/modelmanager.py diff --git a/visual-retrieval-colpali/backend/stopwords.py b/visual-retrieval-colpali/src/backend/stopwords.py similarity index 100% rename from visual-retrieval-colpali/backend/stopwords.py rename to visual-retrieval-colpali/src/backend/stopwords.py diff --git a/visual-retrieval-colpali/backend/testquery.py b/visual-retrieval-colpali/src/backend/testquery.py similarity index 100% rename from visual-retrieval-colpali/backend/testquery.py rename to visual-retrieval-colpali/src/backend/testquery.py diff --git a/visual-retrieval-colpali/backend/vespa_app.py b/visual-retrieval-colpali/src/backend/vespa_app.py similarity index 99% rename from visual-retrieval-colpali/backend/vespa_app.py rename to visual-retrieval-colpali/src/backend/vespa_app.py index c553a3d8f..b5a53c76d 100644 --- a/visual-retrieval-colpali/backend/vespa_app.py +++ b/visual-retrieval-colpali/src/backend/vespa_app.py @@ -7,9 +7,10 @@ from dotenv import load_dotenv from vespa.application import Vespa from vespa.io import VespaQueryResponse -from .colpali import should_filter_token +from .colpali import SimMapGenerator import backend.stopwords + class VespaQueryClient: MAX_QUERY_TERMS = 64 VESPA_SCHEMA_NAME = "pdf_page" @@ -364,7 +365,7 @@ def results_to_search_results( fields_to_add = [ f"sim_map_{token}_{idx}" for idx, token in idx_to_token.items() - if not should_filter_token(token) + if not SimMapGenerator.should_filter_token(token) ] for child in result["root"]["children"]: for sim_map_key in fields_to_add: diff --git a/visual-retrieval-colpali/frontend/__init__.py b/visual-retrieval-colpali/src/frontend/__init__.py similarity index 100% rename from visual-retrieval-colpali/frontend/__init__.py rename to visual-retrieval-colpali/src/frontend/__init__.py diff --git a/visual-retrieval-colpali/frontend/app.py b/visual-retrieval-colpali/src/frontend/app.py similarity index 100% rename from visual-retrieval-colpali/frontend/app.py rename to visual-retrieval-colpali/src/frontend/app.py diff --git a/visual-retrieval-colpali/frontend/layout.py b/visual-retrieval-colpali/src/frontend/layout.py similarity index 100% rename from visual-retrieval-colpali/frontend/layout.py rename to visual-retrieval-colpali/src/frontend/layout.py diff --git a/visual-retrieval-colpali/globals.css b/visual-retrieval-colpali/src/globals.css similarity index 100% rename from visual-retrieval-colpali/globals.css rename to visual-retrieval-colpali/src/globals.css diff --git a/visual-retrieval-colpali/icons.py b/visual-retrieval-colpali/src/icons.py similarity index 100% rename from visual-retrieval-colpali/icons.py rename to visual-retrieval-colpali/src/icons.py diff --git a/visual-retrieval-colpali/main.py b/visual-retrieval-colpali/src/main.py similarity index 86% rename from visual-retrieval-colpali/main.py rename to 
visual-retrieval-colpali/src/main.py index 8a2d64f0d..90c0f45f7 100644 --- a/visual-retrieval-colpali/main.py +++ b/visual-retrieval-colpali/src/main.py @@ -28,8 +28,7 @@ from shad4fast import ShadHead from vespa.application import Vespa -from backend.colpali import gen_similarity_maps, get_query_embeddings_and_token_map -from backend.modelmanager import ModelManager +from backend.colpali import SimMapGenerator from backend.vespa_app import VespaQueryClient from frontend.app import ( AboutThisDemo, @@ -108,7 +107,7 @@ @app.on_event("startup") def load_model_on_startup(): - app.manager = ModelManager.get_instance() + app.sim_map_generator = SimMapGenerator() return @@ -141,19 +140,16 @@ def get(): @rt("/search") -def get(request): - # Extract the 'query' and 'ranking' parameters from the URL - query_value = request.query_params.get("query", "").strip() - ranking_value = request.query_params.get("ranking", "nn+colpali") - print("/search: Fetching results for ranking_value:", ranking_value) +def get(request, query: str = "", ranking: str = "nn+colpali"): + print("/search: Fetching results for ranking_value:", ranking) # Always render the SearchBox first - if not query_value: + if not query: # Show SearchBox and a message for missing query return Layout( Main( Div( - SearchBox(query_value=query_value, ranking_value=ranking_value), + SearchBox(query_value=query, ranking_value=ranking), Div( P( "No query provided. Please enter a query.", @@ -166,35 +162,17 @@ def get(request): ) ) # Generate a unique query_id based on the query and ranking value - query_id = generate_query_id(query_value, ranking_value) + query_id = generate_query_id(query, ranking) # Show the loading message if a query is provided return Layout( Main(Search(request), data_overlayscrollbars_initialize=True, cls="border-t"), Aside( - ChatResult(query_id=query_id, query=query_value), + ChatResult(query_id=query_id, query=query), cls="border-t border-l hidden md:block", ), ) # Show SearchBox and Loading message initially -@rt("/fetch_results2") -def get(query: str, ranking: str): - # 1. Get the results from Vespa (without sim_maps and full_images) - # Call search-endpoint in Vespa sync. - - # 2. Kick off tasks to fetch sim_maps and full_images - # Sim maps - call search endpoint async. - # (A) New rank_profile that does not calculate sim_maps. - # (A) Make vespa endpoints take select_fields as a parameter. - # One sim map per image per token. - # the filename query_id_result_idx_token_idx.png - # Full image. based on the doc_id. - # Each of these tasks saves to disk. - # Need a cleanup task to delete old files. - # Polling endpoints for sim_maps and full_images checks if file exists and returns it. 
- pass - - @rt("/fetch_results") async def get(session, request, query: str, ranking: str): if "hx-request" not in request.headers: @@ -204,9 +182,10 @@ async def get(session, request, query: str, ranking: str): query_id = generate_query_id(query, ranking) print(f"Query id in /fetch_results: {query_id}") # Run the embedding and query against Vespa app - model = app.manager.model - processor = app.manager.processor - q_embs, idx_to_token = get_query_embeddings_and_token_map(processor, model, query) + + q_embs, idx_to_token = app.sim_map_generator.get_query_embeddings_and_token_map( + query + ) start = time.perf_counter() # Fetch real search results from Vespa @@ -275,10 +254,7 @@ def get_and_store_sim_maps( if not all([os.path.exists(img_path) for img_path in img_paths]): print(f"Images not ready in 5 seconds for query_id: {query_id}") return False - sim_map_generator = gen_similarity_maps( - model=app.manager.model, - processor=app.manager.processor, - device=app.manager.device, + sim_map_generator = app.sim_map_generator.gen_similarity_maps( query=query, query_embs=q_embs, token_idx_map=idx_to_token, @@ -340,8 +316,9 @@ async def full_image(doc_id: str): @rt("/suggestions") -async def get_suggestions(request): - query = request.query_params.get("query", "").lower().strip() +async def get_suggestions(query: str = ""): + """Endpoint to get suggestions as user types in the search box""" + query = query.lower().strip() if query: suggestions = await vespa_app.get_suggestions(query) @@ -352,12 +329,16 @@ async def get_suggestions(request): async def message_generator(query_id: str, query: str, doc_ids: list): + """Generator function to yield SSE messages for chat response""" images = {} num_images = 3 # Number of images before firing chat request max_wait = 10 # seconds start_time = time.time() # Check if full images are ready on disk - while len(images) < min(num_images, len(doc_ids)) and time.time() - start_time < max_wait: + while ( + len(images) < min(num_images, len(doc_ids)) + and time.time() - start_time < max_wait + ): for idx in range(num_images): image_filename = IMG_DIR / f"{doc_ids[idx]}.jpg" if not os.path.exists(image_filename): diff --git a/visual-retrieval-colpali/output.css b/visual-retrieval-colpali/src/output.css similarity index 100% rename from visual-retrieval-colpali/output.css rename to visual-retrieval-colpali/src/output.css diff --git a/visual-retrieval-colpali/requirements.txt b/visual-retrieval-colpali/src/requirements.txt similarity index 100% rename from visual-retrieval-colpali/requirements.txt rename to visual-retrieval-colpali/src/requirements.txt diff --git a/visual-retrieval-colpali/static/img/vespa-colpali.png b/visual-retrieval-colpali/src/static/img/vespa-colpali.png similarity index 100% rename from visual-retrieval-colpali/static/img/vespa-colpali.png rename to visual-retrieval-colpali/src/static/img/vespa-colpali.png diff --git a/visual-retrieval-colpali/static/js/highlightjs-theme.js b/visual-retrieval-colpali/src/static/js/highlightjs-theme.js similarity index 100% rename from visual-retrieval-colpali/static/js/highlightjs-theme.js rename to visual-retrieval-colpali/src/static/js/highlightjs-theme.js diff --git a/visual-retrieval-colpali/tailwind.config.js b/visual-retrieval-colpali/src/tailwind.config.js similarity index 100% rename from visual-retrieval-colpali/tailwind.config.js rename to visual-retrieval-colpali/src/tailwind.config.js diff --git a/visual-retrieval-colpali/tailwindcss b/visual-retrieval-colpali/src/tailwindcss similarity index 100% 
rename from visual-retrieval-colpali/tailwindcss rename to visual-retrieval-colpali/src/tailwindcss diff --git a/visual-retrieval-colpali/static/assets/ConocoPhillips Sustainability Highlights - Nature (24-0976).png b/visual-retrieval-colpali/static/assets/ConocoPhillips Sustainability Highlights - Nature (24-0976).png deleted file mode 100644 index 241d09566..000000000 Binary files a/visual-retrieval-colpali/static/assets/ConocoPhillips Sustainability Highlights - Nature (24-0976).png and /dev/null differ
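
For reference, the refactor above replaces the module-level `gen_similarity_maps` and `get_query_embeddings_and_token_map` functions with the `SimMapGenerator` class, which `src/main.py` now instantiates once on startup. The following is a minimal usage sketch based only on the signatures visible in this diff, intended to be run from the `src` directory; the image path and the quantized Vespa similarity-map payload are placeholder values, not data from the actual application.

```python
from pathlib import Path

from backend.colpali import SimMapGenerator

# Instantiate once, as main.py now does in its startup handler; this loads the
# ColPali model and processor onto the auto-detected device.
sim_map_generator = SimMapGenerator()  # defaults: model_name="vidore/colpali-v1.2", n_patch=32

# Embed the query and build the index -> token map used to label sim maps.
query = "percentage of renewable energy"
q_embs, idx_to_token = sim_map_generator.get_query_embeddings_and_token_map(query)

# Placeholder inputs: one page image plus its quantized similarity map in the
# cell layout that _prepare_similarity_map_tensor expects (hypothetical values).
images = [Path("static/full_images/example_doc.jpg")]
vespa_sim_maps = [
    {
        "quantized": {
            "cells": [
                {"address": {"patch": "0", "querytoken": "0"}, "value": 12},
            ]
        }
    }
]

# gen_similarity_maps is a generator yielding one blended heatmap per
# (image, non-filtered query token) pair, base64-encoded as PNG.
for idx, token, token_idx, blended_img_base64 in sim_map_generator.gen_similarity_maps(
    query=query,
    query_embs=q_embs,
    token_idx_map=idx_to_token,
    images=images,
    vespa_sim_maps=vespa_sim_maps,
):
    print(f"image {idx}, token {token!r} (idx {token_idx}): {len(blended_img_base64)} base64 chars")
```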