Merge pull request #1567 from vespa-engine/thomasht86/logging-not-print

(colpalidemo) logging not print

thomasht86 authored Nov 6, 2024
2 parents ad2f0a6 + 16ab8a3, commit 06817d2

Showing 5 changed files with 72 additions and 71 deletions.
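The pattern the diff applies throughout is a standard one: configure a single named logging.Logger once at startup, route it to stdout, and pass it into the backend classes instead of calling print(). A minimal standalone sketch of that setup, mirroring the code added in src/main.py (the two trailing log calls are illustrative):

import logging
import os
import sys

# Mirrors the setup the commit adds in src/main.py: one named logger,
# a stdout handler, and a level taken from the LOG_LEVEL env var.
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
logger = logging.getLogger("vespa_app")
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(
    logging.Formatter(
        "%(levelname)s: \t %(asctime)s \t %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
)
logger.addHandler(handler)
logger.setLevel(getattr(logging, LOG_LEVEL))

logger.info("Connected to Vespa")     # visible at the default INFO level
logger.debug("query timing: 0.12 s")  # emitted only when LOG_LEVEL=DEBUG

With print(), every message reaches stdout unconditionally; with a leveled logger, the per-query diagnostics below can stay in the code but cost nothing in production at INFO.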
14 changes: 11 additions & 3 deletions visual-retrieval-colpali/src/backend/colpali.py
@@ -14,6 +14,8 @@
 from vidore_benchmark.interpretability.torch_utils import (
     normalize_similarity_map_per_query_token,
 )
+from functools import lru_cache
+import logging
 
 
 class SimMapGenerator:
@@ -23,7 +25,12 @@ class SimMapGenerator:
 
     colormap = cm.get_cmap("viridis")  # Preload colormap for efficiency
 
-    def __init__(self, model_name: str = "vidore/colpali-v1.2", n_patch: int = 32):
+    def __init__(
+        self,
+        logger: logging.Logger,
+        model_name: str = "vidore/colpali-v1.2",
+        n_patch: int = 32,
+    ):
         """
         Initializes the SimMapGenerator class with a specified model and patch dimension.
@@ -34,7 +41,8 @@ def __init__(self, model_name: str = "vidore/colpali-v1.2", n_patch: int = 32):
         self.model_name = model_name
         self.n_patch = n_patch
         self.device = get_torch_device("auto")
-        print(f"Using device: {self.device}")
+        self.logger = logger
+        self.logger.info(f"Using device: {self.device}")
         self.model, self.processor = self.load_model()
 
     def load_model(self) -> Tuple[ColPali, ColPaliProcessor]:
@@ -249,7 +257,7 @@ def should_filter_token(token: str) -> bool:
         )
         return bool(pattern.match(token))
 
-    # TODO: Would be nice to @lru_cache this method.
+    @lru_cache(maxsize=128)
     def get_query_embeddings_and_token_map(
         self, query: str
     ) -> Tuple[torch.Tensor, dict]:
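A note on the @lru_cache change, not stated in the diff: on an instance method the cache key includes self, so each SimMapGenerator instance gets its own entries, and the cache holds a reference to self for as long as its entries live. That is harmless here because the app creates one long-lived generator at startup. A small sketch of the behavior, with a hypothetical Embedder standing in for SimMapGenerator:

from functools import lru_cache


class Embedder:
    # Hypothetical stand-in for SimMapGenerator. The cache key is
    # (self, query), so results are cached per instance, and the cache
    # keeps self alive while entries for it remain.
    @lru_cache(maxsize=128)
    def embed(self, query: str) -> tuple:
        print(f"computing embeddings for {query!r}")
        return (query, len(query))


e = Embedder()
e.embed("manhattan skyline")  # prints: computing embeddings for ...
e.embed("manhattan skyline")  # cache hit: no recomputation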
24 changes: 0 additions & 24 deletions visual-retrieval-colpali/src/backend/modelmanager.py

This file was deleted.

25 changes: 13 additions & 12 deletions visual-retrieval-colpali/src/backend/vespa_app.py
@@ -9,21 +9,23 @@
 from vespa.io import VespaQueryResponse
 from .colpali import SimMapGenerator
 import backend.stopwords
+import logging
 
 
 class VespaQueryClient:
     MAX_QUERY_TERMS = 64
     VESPA_SCHEMA_NAME = "pdf_page"
     SELECT_FIELDS = "id,title,url,blur_image,page_number,snippet,text"
 
-    def __init__(self):
+    def __init__(self, logger: logging.Logger):
         """
         Initialize the VespaQueryClient by loading environment variables and establishing a connection to the Vespa application.
         """
         load_dotenv()
+        self.logger = logger
 
         if os.environ.get("USE_MTLS") == "true":
-            print("Connected using mTLS")
+            self.logger.info("Connected using mTLS")
             mtls_key = os.environ.get("VESPA_CLOUD_MTLS_KEY")
             mtls_cert = os.environ.get("VESPA_CLOUD_MTLS_CERT")
@@ -52,7 +54,7 @@ def __init__(self):
                 url=self.vespa_app_url, key=mtls_key_path, cert=mtls_cert_path
             )
         else:
-            print("Connected using token")
+            self.logger.info("Connected using token")
             self.vespa_app_url = os.environ.get("VESPA_APP_TOKEN_URL")
             if not self.vespa_app_url:
                 raise ValueError(
@@ -73,7 +75,7 @@ def __init__(self):
             )
 
         self.app.wait_for_application_up()
-        print(f"Connected to Vespa at {self.vespa_app_url}")
+        self.logger.info(f"Connected to Vespa at {self.vespa_app_url}")
 
     def get_fields(self, sim_map: bool = False):
         if not sim_map:
@@ -99,7 +101,7 @@ def format_query_results(
         query_time = round(query_time, 2)
         count = response.json.get("root", {}).get("fields", {}).get("totalCount", 0)
         result_text = f"Query text: '{query}', query time {query_time}s, count={count}, top results:\n"
-        print(result_text)
+        self.logger.debug(result_text)
         return response.json
 
     async def query_vespa_default(
@@ -143,7 +145,7 @@ async def query_vespa_default(
             )
             assert response.is_successful(), response.json
         stop = time.perf_counter()
-        print(
+        self.logger.debug(
             f"Query time + data transfer took: {stop - start} s, Vespa reported searchtime was "
             f"{response.json.get('timing', {}).get('searchtime', -1)} s"
         )
@@ -190,7 +192,7 @@ async def query_vespa_bm25(
             )
             assert response.is_successful(), response.json
         stop = time.perf_counter()
-        print(
+        self.logger.debug(
             f"Query time + data transfer took: {stop - start} s, Vespa reported searchtime was "
             f"{response.json.get('timing', {}).get('searchtime', -1)} s"
         )
@@ -215,7 +217,7 @@ def float_to_binary_embedding(self, float_query_embedding: dict) -> dict:
             )
             binary_query_embeddings[key] = binary_vector
             if len(binary_query_embeddings) >= self.MAX_QUERY_TERMS:
-                print(
+                self.logger.warning(
                     f"Warning: Query has more than {self.MAX_QUERY_TERMS} terms. Truncating."
                 )
                 break
@@ -292,12 +294,11 @@ async def get_result_from_query(
             result = await self.query_vespa_bm25(query, q_embs, sim_map=sim_map)
         else:
             raise ValueError(f"Unsupported ranking: {rank_method}")
-        # Print score, title id, and text of the results
         if "root" not in result or "children" not in result["root"]:
             result["root"] = {"children": []}
             return result
         for single_result in result["root"]["children"]:
-            print(single_result["fields"].keys())
+            self.logger.debug(single_result["fields"].keys())
         return result
 
     def get_sim_maps_from_query(
@@ -349,7 +350,7 @@ async def get_full_image_from_vespa(self, doc_id: str) -> str:
             )
             assert response.is_successful(), response.json
         stop = time.perf_counter()
-        print(
+        self.logger.debug(
             f"Getting image from Vespa took: {stop - start} s, Vespa reported searchtime was "
             f"{response.json.get('timing', {}).get('searchtime', -1)} s"
         )
@@ -386,7 +387,7 @@ async def get_suggestions(self, query: str) -> list:
             )
             assert response.is_successful(), response.json
         stop = time.perf_counter()
-        print(
+        self.logger.debug(
             f"Getting suggestions from Vespa took: {stop - start} s, Vespa reported searchtime was "
             f"{response.json.get('timing', {}).get('searchtime', -1)} s"
        )
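Since VespaQueryClient now receives its logger through the constructor, callers decide where the output goes. A hedged example of why that helps (the helper below is hypothetical, and the import path is assumed from the repo layout): a test can hand the client a logger that discards everything.

import logging

from backend.vespa_app import VespaQueryClient  # assumed import path


def make_quiet_client() -> VespaQueryClient:
    # Hypothetical test helper: a logger with only a NullHandler keeps
    # client construction and queries silent, e.g. in unit tests.
    quiet = logging.getLogger("test.vespa_app")
    quiet.addHandler(logging.NullHandler())
    quiet.propagate = False
    return VespaQueryClient(logger=quiet)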
25 changes: 10 additions & 15 deletions visual-retrieval-colpali/src/frontend/app.py
@@ -199,7 +199,7 @@ def SearchBox(with_border=False, query_value="", ranking_value="nn+colpali"):
         autocomplete_script,
         action=f"/search?query={quote_plus(query_value)}&ranking={quote_plus(ranking_value)}",
         method="GET",
-        hx_get=f"/fetch_results?query={quote_plus(query_value)}&ranking={quote_plus(ranking_value)}",
+        hx_get="/fetch_results",  # Since the component is a form, the query and ranking inputs are sent as query parameters automatically; see https://htmx.org/docs/#parameters
         hx_trigger="load",
         hx_target="#search-results",
         hx_swap="outerHTML",
@@ -310,9 +310,6 @@ def AboutThisDemo():
 def Search(request, search_results=[]):
     query_value = request.query_params.get("query", "").strip()
     ranking_value = request.query_params.get("ranking", "nn+colpali")
-    print(
-        f"Search: Fetching results for query: {query_value}, ranking: {ranking_value}"
-    )
     return Div(
         Div(
             Div(
@@ -381,7 +378,8 @@ def SearchInfo(search_time, total_count):
 
 def SearchResult(
     results: list,
-    query: str, query_id: Optional[str] = None,
+    query: str,
+    query_id: Optional[str] = None,
     search_time: float = 0,
     total_count: int = 0,
 ):
@@ -584,16 +582,13 @@ def SearchResult(
     return [
         Div(
             SearchInfo(search_time, total_count),
-    *result_items,
-    image_swapping,
-    toggle_text_content,
-    dynamic_elements_scrollbars,
-    id="search-results",
-    cls="grid grid-cols-1 gap-px bg-border min-h-0",
-    )
-
-
-    ,
+            *result_items,
+            image_swapping,
+            toggle_text_content,
+            dynamic_elements_scrollbars,
+            id="search-results",
+            cls="grid grid-cols-1 gap-px bg-border min-h-0",
+        ),
         Div(
             ChatResult(query_id=query_id, query=query, doc_ids=doc_ids),
             hx_swap_oob="true",
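The hx_get change in SearchBox leans on standard htmx behavior: when the requesting element is (or is inside) a form, the values of its named inputs are serialized onto the request automatically, so nothing needs to be baked into the URL. A stripped-down FastHTML sketch of the idea (element names and values are illustrative, not the app's real markup):

from fasthtml.common import Form, Input

# Because query and ranking are named inputs of the form, htmx sends
# them as ?query=...&ranking=... with the GET to /fetch_results.
search_form = Form(
    Input(type="search", name="query", value=""),
    Input(type="hidden", name="ranking", value="nn+colpali"),
    hx_get="/fetch_results",
    hx_trigger="load",
    hx_target="#search-results",
    hx_swap="outerHTML",
)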
55 changes: 38 additions & 17 deletions visual-retrieval-colpali/src/main.py
@@ -3,6 +3,8 @@
 import os
 import time
 import uuid
+import logging
+import sys
 from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
@@ -68,6 +70,20 @@
 )
 sselink = Script(src="https://unpkg.com/[email protected]/sse.js")
 
+# Get log level from environment variable, default to INFO
+LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
+# Configure logger
+logger = logging.getLogger("vespa_app")
+handler = logging.StreamHandler(sys.stdout)
+handler.setFormatter(
+    logging.Formatter(
+        "%(levelname)s: \t %(asctime)s \t %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+    )
+)
+logger.addHandler(handler)
+logger.setLevel(getattr(logging, LOG_LEVEL))
+
 app, rt = fast_app(
     htmlkw={"cls": "grid h-full"},
     pico=False,
@@ -83,7 +99,7 @@
         ShadHead(tw_cdn=False, theme_handle=True),
     ),
 )
-vespa_app: Vespa = VespaQueryClient()
+vespa_app: Vespa = VespaQueryClient(logger=logger)
 thread_pool = ThreadPoolExecutor()
 # Gemini config
 
@@ -107,7 +123,7 @@
 
 @app.on_event("startup")
 def load_model_on_startup():
-    app.sim_map_generator = SimMapGenerator()
+    app.sim_map_generator = SimMapGenerator(logger=logger)
     return
 
 
@@ -141,7 +157,7 @@ def get():
 
 @rt("/search")
 def get(request, query: str = "", ranking: str = "nn+colpali"):
-    print("/search: Fetching results for ranking_value:", ranking)
+    logger.info(f"/search: Fetching results for query: {query}, ranking: {ranking}")
 
     # Always render the SearchBox first
     if not query:
@@ -180,12 +196,16 @@ async def get(session, request, query: str, ranking: str):
 
     # Get the hash of the query and ranking value
     query_id = generate_query_id(query, ranking)
-    print(f"Query id in /fetch_results: {query_id}")
+    logger.info(f"Query id in /fetch_results: {query_id}")
     # Run the embedding and query against Vespa app
 
+    start_inference = time.perf_counter()
     q_embs, idx_to_token = app.sim_map_generator.get_query_embeddings_and_token_map(
         query
     )
+    end_inference = time.perf_counter()
+    logger.info(
+        f"Inference time for query_id: {query_id} \t {end_inference - start_inference:.2f} seconds"
+    )
 
     start = time.perf_counter()
     # Fetch real search results from Vespa
@@ -196,8 +216,8 @@ async def get(session, request, query: str, ranking: str):
         idx_to_token=idx_to_token,
     )
     end = time.perf_counter()
-    print(
-        f"Search results fetched in {end - start:.2f} seconds, Vespa says searchtime was {result['timing']['searchtime']} seconds"
+    logger.info(
+        f"Search results fetched in {end - start:.2f} seconds. Vespa search time: {result['timing']['searchtime']}"
    )
     search_time = result["timing"]["searchtime"]
     total_count = result["root"]["fields"]["totalCount"]
@@ -228,7 +248,7 @@ async def poll_vespa_keepalive():
     while True:
         await asyncio.sleep(5)
         await vespa_app.keepalive()
-        print(f"Vespa keepalive: {time.time()}")
+        logger.debug(f"Vespa keepalive: {time.time()}")
 
 
 @threaded
@@ -252,7 +272,7 @@ def get_and_store_sim_maps(
     ):
         time.sleep(0.2)
     if not all([os.path.exists(img_path) for img_path in img_paths]):
-        print(f"Images not ready in 5 seconds for query_id: {query_id}")
+        logger.warning(f"Images not ready in 5 seconds for query_id: {query_id}")
         return False
     sim_map_generator = app.sim_map_generator.gen_similarity_maps(
         query=query,
@@ -264,7 +284,7 @@ def get_and_store_sim_maps(
     for idx, token, token_idx, blended_img_base64 in sim_map_generator:
         with open(SIM_MAP_DIR / f"{query_id}_{idx}_{token_idx}.png", "wb") as f:
             f.write(base64.b64decode(blended_img_base64))
-        print(
+        logger.debug(
             f"Sim map saved to disk for query_id: {query_id}, idx: {idx}, token: {token}"
         )
     return True
@@ -279,7 +299,9 @@ async def get_sim_map(query_id: str, idx: int, token: str, token_idx: int):
     """
     sim_map_path = SIM_MAP_DIR / f"{query_id}_{idx}_{token_idx}.png"
     if not os.path.exists(sim_map_path):
-        print(f"Sim map not ready for query_id: {query_id}, idx: {idx}, token: {token}")
+        logger.debug(
+            f"Sim map not ready for query_id: {query_id}, idx: {idx}, token: {token}"
+        )
         return SimMapButtonPoll(
             query_id=query_id, idx=idx, token=token, token_idx=token_idx
         )
@@ -304,7 +326,7 @@ async def full_image(doc_id: str):
         # image data is base 64 encoded string. Save it to disk as jpg.
         with open(img_path, "wb") as f:
             f.write(base64.b64decode(image_data))
-        print(f"Full image saved to disk for doc_id: {doc_id}")
+        logger.debug(f"Full image saved to disk for doc_id: {doc_id}")
     else:
         with open(img_path, "rb") as f:
             image_data = base64.b64encode(f.read()).decode("utf-8")
@@ -343,16 +365,16 @@ async def message_generator(query_id: str, query: str, doc_ids: list):
     for idx in range(num_images):
         image_filename = IMG_DIR / f"{doc_ids[idx]}.jpg"
         if not os.path.exists(image_filename):
-            print(
+            logger.debug(
                 f"Message generator: Full image not ready for query_id: {query_id}, idx: {idx}"
             )
             continue
         else:
-            print(
+            logger.debug(
                 f"Message generator: image ready for query_id: {query_id}, idx: {idx}"
             )
             images.append(Image.open(image_filename))
-        if(len(images) < num_images):
+        if len(images) < num_images:
             await asyncio.sleep(0.2)
 
     # yield message with number of images ready
@@ -392,7 +414,6 @@ def get():
 
 
 if __name__ == "__main__":
-    # ModelManager.get_instance() # Initialize once at startup
     HOT_RELOAD = os.getenv("HOT_RELOAD", "False").lower() == "true"
-    print(f"Starting app with hot reload: {HOT_RELOAD}")
+    logger.info(f"Starting app with hot reload: {HOT_RELOAD}")
     serve(port=7860, reload=HOT_RELOAD)
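One footnote on the new configuration (an observation, not part of the commit): getattr(logging, LOG_LEVEL) raises AttributeError when LOG_LEVEL is set to a name the logging module does not define. A defensive variant would fall back to INFO:

import logging
import os

LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
# getattr with a default tolerates values like LOG_LEVEL=VERBOSE...
level = getattr(logging, LOG_LEVEL, logging.INFO)
if not isinstance(level, int):
    # ...and the isinstance check guards against attributes that exist
    # but are not levels (e.g. LOG_LEVEL=BASIC_FORMAT, a string).
    level = logging.INFO
logging.getLogger("vespa_app").setLevel(level)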
