Skip to content

Commit

Permalink
test csr matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
Thong Nguyen committed Jan 5, 2024
1 parent 115c212 commit 086fb7d
Showing 1 changed file with 31 additions and 1 deletion.
32 changes: 31 additions & 1 deletion retrieve.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from scipy.sparse import csr_matrix
from numba import types, typed, njit
from numba.experimental import jitclass
from pathlib import Path
Expand Down Expand Up @@ -147,8 +148,37 @@ def create_json_query(query_id, topk_toks, topk_weights):
# score += self.text_forward[q_id][tok] * \
# self.image_forward[d_id][tok]
# return score
if args.mode == "csr":
    # Benchmark retrieval as one sparse matrix product: every query is scored
    # against every image by multiplying a (num_texts x vocab) matrix with a
    # (vocab x num_images) matrix.
    # NOTE(review): 30522 is presumably the tokenizer's vocabulary size
    # (it matches BERT's uncased vocab) -- confirm against `tokenizer`.
    vocab_size = 30522

    # Image side: rows are token ids, columns are image indices.
    # The (data, (rows, cols)) constructor sums entries that share the same
    # coordinates, so tokens mapping to the same id have their weights added.
    num_images = len(sparse_images)
    rows = []
    cols = []
    data = []
    for idx, img in enumerate(sparse_images):
        toks = list(img["toks"])
        rows.extend(tokenizer.convert_tokens_to_ids(toks))
        cols.extend([idx] * len(toks))
        data.extend(img["toks"].values())
    image_csr = csr_matrix(
        (data, (rows, cols)), shape=(vocab_size, num_images))

    # Text side: rows are query indices, columns are token ids.
    num_texts = len(sparse_texts)
    rows = []
    cols = []
    data = []
    for idx, text in enumerate(sparse_texts):
        toks = list(text["query_toks"])
        cols.extend(tokenizer.convert_tokens_to_ids(toks))
        data.extend(text["query_toks"].values())
        rows.extend([idx] * len(toks))
    text_csr = csr_matrix(
        (data, (rows, cols)), shape=(num_texts, vocab_size))

    # Only the product itself is timed. perf_counter() is a monotonic,
    # high-resolution clock, so short intervals are measured reliably.
    start = time.perf_counter()
    scores = text_csr @ image_csr
    end = time.perf_counter()
    total_time = end - start

    # Guard the ratios so an empty query set or a zero-length interval
    # reports nan/inf instead of raising ZeroDivisionError.
    sec_per_query = total_time / num_texts if num_texts else float("nan")
    query_per_sec = num_texts / total_time if total_time > 0 else float("inf")
    print(f"Total running time: {total_time} seconds")
    print(f"s/q: {sec_per_query}")
    print(f"q/s: {query_per_sec}")

if args.mode == "faiss":
elif args.mode == "faiss":
import faiss
faiss.omp_set_num_threads(1)
num_images = len(sparse_images)
Expand Down

0 comments on commit 086fb7d

Please sign in to comment.