Skip to content

Commit

Permalink
adapt main
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasht86 committed Nov 1, 2024
1 parent 5cbe3b0 commit 9d203c3
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions visual-retrieval-colpali/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@
from shad4fast import ShadHead
from vespa.application import Vespa

from backend.colpali import gen_similarity_maps, get_query_embeddings_and_token_map
from backend.modelmanager import ModelManager
from backend.colpali import SimMapGenerator
from backend.vespa_app import VespaQueryClient
from frontend.app import (
AboutThisDemo,
Expand Down Expand Up @@ -108,7 +107,7 @@

@app.on_event("startup")
def load_model_on_startup():
app.manager = ModelManager.get_instance()
app.sim_map_generator = SimMapGenerator()
return


Expand Down Expand Up @@ -183,9 +182,10 @@ async def get(session, request, query: str, ranking: str):
query_id = generate_query_id(query, ranking)
print(f"Query id in /fetch_results: {query_id}")
# Run the embedding and query against Vespa app
model = app.manager.model
processor = app.manager.processor
q_embs, idx_to_token = get_query_embeddings_and_token_map(processor, model, query)

q_embs, idx_to_token = app.sim_map_generator.get_query_embeddings_and_token_map(
query
)

start = time.perf_counter()
# Fetch real search results from Vespa
Expand Down Expand Up @@ -254,10 +254,7 @@ def get_and_store_sim_maps(
if not all([os.path.exists(img_path) for img_path in img_paths]):
print(f"Images not ready in 5 seconds for query_id: {query_id}")
return False
sim_map_generator = gen_similarity_maps(
model=app.manager.model,
processor=app.manager.processor,
device=app.manager.device,
sim_map_generator = app.sim_map_generator.gen_similarity_maps(
query=query,
query_embs=q_embs,
token_idx_map=idx_to_token,
Expand Down

0 comments on commit 9d203c3

Please sign in to comment.