-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathui.py
110 lines (95 loc) · 3.62 KB
/
ui.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import streamlit as st
from streamlit import cache_resource
from pymilvus.model.hybrid import BGEM3EmbeddingFunction
from pymilvus import (
utility,
FieldSchema, CollectionSchema, DataType,
Collection, AnnSearchRequest, WeightedRanker, connections,
)
@cache_resource
def get_model():
    """Build the BGE-M3 embedding function used by the demo.

    The model is created on the CPU with fp16 disabled. The function is
    wrapped in Streamlit's ``cache_resource`` so the instance is shared
    instead of being rebuilt on every call.

    Returns:
        The ``BGEM3EmbeddingFunction`` instance.
    """
    return BGEM3EmbeddingFunction(use_fp16=False, device="cpu")
@cache_resource
def get_collection():
    """Connect to Milvus and return the demo collection handle.

    Registers the ``"default"`` connection alias against the ``milvus.db``
    URI, then opens the ``hybrid_demo`` collection. The function is wrapped
    in Streamlit's ``cache_resource`` so the handle is shared across calls.

    Returns:
        The ``Collection`` object for ``hybrid_demo``.
    """
    # The connection must exist before the Collection handle is created.
    connections.connect("default", uri="milvus.db")
    return Collection('hybrid_demo')
# Page header and query controls. `query` and `search_button` are read by the
# search section at the bottom of the script.
st.title("Milvus Hybrid Search Demo")
query = st.text_input("Enter your search query:")
search_button = st.button("Search")
@cache_resource
def get_tokenizer():
    """Return the tokenizer that belongs to the cached embedding model.

    Returns:
        The ``model.tokenizer`` attribute of the object returned by
        ``get_model()``.
    """
    return get_model().model.tokenizer
def doc_text_colorization(query, docs):
    """Mark the parts of each document whose tokens also occur in the query.

    Every document is tokenized with the model tokenizer; each run of
    document tokens that also appear among the query's tokens is wrapped in
    ``:red[...]`` so it is rendered in red when passed to ``st.markdown``.

    Args:
        query: The search query string.
        docs: Iterable of document strings to annotate.

    Returns:
        A list with one annotated string per document, in input order.
    """
    tokenizer = get_tokenizer()
    query_tokens_ids = tokenizer.encode(query, return_offsets_mapping=True)
    # A set makes the per-token membership test below constant-time instead
    # of a linear scan over the query tokens.
    query_tokens = set(tokenizer.convert_ids_to_tokens(query_tokens_ids))

    return [
        _wrap_spans(doc, _matching_spans(tokenizer, query_tokens, doc))
        for doc in docs
    ]


def _matching_spans(tokenizer, query_tokens, doc):
    """Return merged ``[start, end]`` character spans of doc tokens found in the query."""
    encoding = tokenizer.encode_plus(doc, return_offsets_mapping=True)
    # Drop the first and last entries -- presumably the special tokens the
    # tokenizer puts around the sequence (TODO confirm for other tokenizers).
    tokens = tokenizer.convert_ids_to_tokens(encoding['input_ids'])[1:-1]
    offsets = encoding['offset_mapping'][1:-1]

    spans = []
    for token, (start, end) in zip(tokens, offsets):
        # Skip non-matching tokens, and zero-width spans: they cover no text
        # and would otherwise desynchronise the marker placement.
        if token not in query_tokens or start >= end:
            continue
        if spans and start <= spans[-1][1]:
            # Touching (or overlapping) the previous span: extend it so that
            # adjacent sub-word pieces are highlighted as one run.
            spans[-1][1] = max(spans[-1][1], end)
        else:
            spans.append([start, end])
    return spans


def _wrap_spans(doc, spans):
    """Return ``doc`` with every ``[start, end]`` span wrapped in ``:red[...]``."""
    pieces = []
    cursor = 0
    for start, end in spans:
        pieces += (doc[cursor:start], ':red[', doc[start:end], ']')
        cursor = end
    # Whatever follows the last highlighted span is copied through unchanged.
    pieces.append(doc[cursor:])
    return ''.join(pieces)
def hybrid_search(query_embeddings, sparse_weight=1.0, dense_weight=1.0):
    """Run a weighted sparse + dense search and return the matched texts.

    Args:
        query_embeddings: Mapping with ``"sparse"`` and ``"dense"`` entries
            holding the query vectors for the respective fields.
        sparse_weight: Weight given to the sparse request by the ranker.
        dense_weight: Weight given to the dense request by the ranker.

    Returns:
        The ``text`` field of each hit for the first query (at most 10),
        or an empty list when the search returns nothing.
    """
    collection = get_collection()

    # One ANN request per vector field, both scored by inner product.
    requests = [
        AnnSearchRequest(query_embeddings["sparse"], "sparse_vector", {"metric_type": "IP"}, limit=10),
        AnnSearchRequest(query_embeddings["dense"], "dense_vector", {"metric_type": "IP"}, limit=10),
    ]
    ranker = WeightedRanker(sparse_weight, dense_weight)

    hits_per_query = collection.hybrid_search(
        requests, rerank=ranker, limit=10, output_fields=['text']
    )
    if len(hits_per_query) == 0:
        return []
    # Only the first result set is read.
    return [hit.fields["text"] for hit in hits_per_query[0]]
# Once the button is clicked with a non-empty query, run the three searches
# and render their results side by side.
if search_button and query:
    ef = get_model()
    query_embeddings = ef([query])

    # (header, sparse weight, dense weight, highlight query tokens?)
    panels = (
        ("Dense", 0.0, 1.0, False),
        ("Sparse", 1.0, 0.0, True),
        ("Hybrid", 0.7, 1.0, True),
    )
    for column, (header, sparse_w, dense_w, highlight) in zip(st.columns(3), panels):
        with column:
            st.header(header)
            results = hybrid_search(query_embeddings, sparse_weight=sparse_w, dense_weight=dense_w)
            if highlight:
                # Token highlighting is only applied to the sparse and hybrid panels.
                results = doc_text_colorization(query, results)
            for result in results:
                st.markdown(result)