Merge pull request #34 from sdan/sdan/sd-208-feat-binary-embeddings-int8-rescorer

binary embeddings & int8 rescoring
sdan authored Apr 4, 2024
2 parents c61cbae + 1f8d05a commit 712bf92
Showing 6 changed files with 4,615 additions and 62 deletions.
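Only the requirements.txt and tests/benchtest.py diffs are rendered below; the binary-embedding and int8-rescoring implementation itself lives in the other changed files. As rough orientation, the general technique the commit title refers to is usually built along these lines: embeddings are quantized to one bit per dimension for a fast Hamming-distance candidate search, and the shortlist is then re-ranked with int8-quantized vectors. The following is a minimal NumPy sketch of that idea, not vlite's actual code; the function names, the shared scale calibration constant, and the 4x oversampling factor are assumptions.

# Hypothetical sketch of binary quantization + int8 rescoring (illustration only,
# not the code added in this commit).
import numpy as np

def quantize(embeddings: np.ndarray, scale: float):
    """Derive packed 1-bit codes and int8 codes from float32 embeddings."""
    binary = np.packbits((embeddings > 0).astype(np.uint8), axis=-1)      # 1 bit per dimension
    int8 = np.clip(np.round(embeddings / scale * 127), -128, 127).astype(np.int8)
    return binary, int8

def search(query, binary_index, int8_index, scale, top_k=10, oversample=4):
    """Coarse Hamming-distance search over binary codes, then int8 dot-product rescoring."""
    q_bin, q_int8 = quantize(query[None, :], scale)
    hamming = np.unpackbits(binary_index ^ q_bin, axis=-1).sum(axis=-1)   # popcount of XOR
    candidates = np.argsort(hamming)[: top_k * oversample]                # oversampled shortlist
    scores = int8_index[candidates].astype(np.int32) @ q_int8[0].astype(np.int32)
    return candidates[np.argsort(-scores)[:top_k]]                        # re-ranked top_k ids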
8 changes: 6 additions & 2 deletions requirements.txt
@@ -17,7 +17,11 @@ chromadb==0.4.24
 qdrant-client
 git+https://github.com/sdan/surya.git
 beautifulsoup4==4.12.3
-llama-cpp-python==0.2.58
+llama-cpp-python==0.2.59
 huggingface_hub
 fastapi==0.110.1
-python-multipart
+python-multipart
+lancedb
+langchain==0.1.14
+langchain-community
+faiss-cpu=1.8.0
115 changes: 109 additions & 6 deletions tests/benchtest.py
@@ -16,6 +16,14 @@
 from qdrant_client.models import Distance, VectorParams, PointStruct
 from sentence_transformers import SentenceTransformer
 
+from langchain_community.embeddings import HuggingFaceEmbeddings
+
+from langchain.document_loaders import TextLoader
+from langchain.vectorstores import LanceDB, Lantern, FAISS
+from langchain.text_splitter import CharacterTextSplitter
+
+from langchain.docstore.document import Document
+
 
 def main(queries, corpuss, top_k, token_counts) -> pd.DataFrame:
     """Run the VLite benchmark.
@@ -84,7 +92,7 @@ def main(queries, corpuss, top_k, token_counts) -> pd.DataFrame:

     results.append(
         {
-            "num_embeddings": len(results_retrieve),
+            "num_embeddings": len(result_add),
             "lib": "VLite",
             "k": top_k,
             "avg_time": np.mean(times),
@@ -95,6 +103,101 @@ def main(queries, corpuss, top_k, token_counts) -> pd.DataFrame:
     print(json.dumps(results[-1], indent=2))
     print("Done VLite benchmark.")
 
+
+    #################################################
+    #                    LanceDB                    #
+    #################################################
+    print("Begin LanceDB benchmark.")
+    print("Adding documents to LanceDB instance...")
+    t0 = time.time()
+
+    embeddings = HuggingFaceEmbeddings(model_name="mixedbread-ai/mxbai-embed-large-v1")
+    documents = [Document(page_content=text) for text in corpus]
+    docsearch = LanceDB.from_documents(documents, embeddings)
+
+    t1 = time.time()
+    print(f"Took {t1 - t0:.3f}s to add documents.")
+    indexing_times.append(
+        {
+            "num_tokens": token_count,
+            "lib": "LanceDB",
+            "num_embeddings": len(documents),
+            "indexing_time": t1 - t0,
+        }
+    )
+
+    print("Starting LanceDB trials...")
+    times = []
+    for query in queries:
+        t0 = time.time()
+        docs = docsearch.similarity_search(query, k=top_k)
+        t1 = time.time()
+        times.append(t1 - t0)
+
+        print(f"Top {top_k} results for query '{query}':")
+        for doc in docs:
+            print(f"Text: {doc.page_content}\n---")
+
+    results.append(
+        {
+            "num_embeddings": len(documents),
+            "lib": "LanceDB",
+            "k": top_k,
+            "avg_time": np.mean(times),
+            "stddev_time": np.std(times),
+        }
+    )
+
+    print(json.dumps(results[-1], indent=2))
+    print("Done LanceDB benchmark.")
+
+    #################################################
+    #                     FAISS                     #
+    #################################################
+    print("Begin FAISS benchmark.")
+    print("Adding documents to FAISS instance...")
+    t0 = time.time()
+
+    embeddings = HuggingFaceEmbeddings(model_name="mixedbread-ai/mxbai-embed-large-v1")
+    documents = [Document(page_content=text) for text in corpus]
+    db = FAISS.from_documents(documents, embeddings)
+
+    t1 = time.time()
+    print(f"Took {t1 - t0:.3f}s to add documents.")
+    indexing_times.append(
+        {
+            "num_tokens": token_count,
+            "lib": "FAISS",
+            "num_embeddings": len(corpus),
+            "indexing_time": t1 - t0,
+        }
+    )
+
+    print("Starting FAISS trials...")
+    times = []
+    for query in queries:
+        t0 = time.time()
+        docs = db.similarity_search(query, k=top_k)
+        t1 = time.time()
+        times.append(t1 - t0)
+
+        print(f"Top {top_k} results for query '{query}':")
+        for doc in docs:
+            print(f"Text: {doc.page_content}\n---")
+
+    results.append(
+        {
+            "num_embeddings": len(corpus),
+            "lib": "FAISS",
+            "k": top_k,
+            "avg_time": np.mean(times),
+            "stddev_time": np.std(times),
+        }
+    )
+
+    print(json.dumps(results[-1], indent=2))
+    print("Done FAISS benchmark.")
+
 
     #################################################
     #                    Chroma                     #
     #################################################
@@ -104,7 +207,7 @@ def main(queries, corpuss, top_k, token_counts) -> pd.DataFrame:

     chroma_client = chromadb.Client()
     sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="mixedbread-ai/mxbai-embed-large-v1")
-    collection = chroma_client.create_collection(name="my_collection", embedding_function=sentence_transformer_ef)
+    collection = chroma_client.get_or_create_collection(name="benchtest", embedding_function=sentence_transformer_ef)
     ids = [str(i) for i in range(len(corpus))]
     try:
         collection.add(documents=corpus, ids=ids)
@@ -176,10 +279,10 @@ def main(queries, corpuss, top_k, token_counts) -> pd.DataFrame:
             points=[
                 PointStruct(
                     id=idx,
-                    vector=model.encode(vector).tolist(),
-                    payload={"text": corpus[idx]}
+                    vector=model.encode(text).tolist(),
+                    payload={"text": text}
                 )
-                for idx, vector in enumerate(corpus)
+                for idx, text in enumerate(corpus)
             ]
         )
     except Exception as e:
@@ -203,7 +306,7 @@ def main(queries, corpuss, top_k, token_counts) -> pd.DataFrame:

     for i in range(len(query)):
         t0 = time.time()
-        query_vector = model.encode(query[i]).tolist()
+        query_vector = model.encode(query).tolist()
         try:
             hits = qdrant_client.search(
                 collection_name="my_collection",