Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hybrid retrieval #152

Merged
merged 9 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions buster/busterbot.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import logging
from dataclasses import dataclass, field
from typing import Any, Optional
from typing import Optional

import pandas as pd

from buster.completers import Completion, DocumentAnswerer, UserInputs
from buster.llm_utils import QuestionReformulator
from buster.llm_utils import QuestionReformulator, get_openai_embedding
from buster.retriever import Retriever
from buster.validators import Validator

Expand Down Expand Up @@ -33,7 +33,7 @@ class BusterConfig:
"max_tokens": 3000,
"top_k": 3,
"thresh": 0.7,
"embedding_model": "text-embedding-ada-002",
"embedding_fn": get_openai_embedding,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

did you rebase / merge main? This should already be there at this point

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes it's rebased. Not sure why it still shows this

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wait I checked main, we didn't do the change actually. Seems like a bug fix then :-)

}
)
prompt_formatter_cfg: dict = field(
Expand Down
16 changes: 13 additions & 3 deletions buster/documents_manager/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional
from typing import Callable, Optional

import numpy as np
import pandas as pd
from tqdm import tqdm

Expand Down Expand Up @@ -63,7 +64,8 @@ def add(
self,
df: pd.DataFrame,
num_workers: int = 16,
embedding_fn: callable = get_openai_embedding,
embedding_fn: Callable[[str], np.ndarray] = get_openai_embedding,
sparse_embedding_fn: Callable[[str], dict[str, list[float]]] = None,
hbertrand marked this conversation as resolved.
Show resolved Hide resolved
csv_filename: Optional[str] = None,
csv_overwrite: bool = True,
**add_kwargs,
Expand All @@ -81,6 +83,8 @@ def add(
num_workers (int, optional): The number of parallel workers to use for computing embeddings. Default is 16.
embedding_fn (callable, optional): A function that computes embeddings for a given input string.
Default is 'get_openai_embedding', which uses the text-embedding-ada-002 model.
sparse_embedding_fn (callable, optional): A function that computes sparse embeddings for a given input string.
Default is None. Only use if you want sparse embeddings.
csv_filename (str, optional): Path to save a copy of the dataframe with computed embeddings for later use.
csv_overwrite (bool, optional): Whether to overwrite the file with a new file. Defaults to True.
**add_kwargs: Additional keyword arguments to be passed to the '_add_documents' method.
Expand All @@ -92,6 +96,8 @@ def add(
# Check if embeddings are present, computes them if not
if "embedding" not in df.columns:
df["embedding"] = compute_embeddings_parallelized(df, embedding_fn=embedding_fn, num_workers=num_workers)
if "sparse_embedding" not in df.columns and sparse_embedding_fn is not None:
df["sparse_embedding"] = sparse_embedding_fn(df.content.to_list())

if csv_filename is not None:
self._checkpoint_csv(df, csv_filename=csv_filename, csv_overwrite=csv_overwrite)
Expand All @@ -104,7 +110,8 @@ def batch_add(
batch_size: int = 3000,
min_time_interval: int = 60,
num_workers: int = 16,
embedding_fn: callable = get_openai_embedding,
embedding_fn: Callable[[str], np.ndarray] = get_openai_embedding,
sparse_embedding_fn: Callable[[str], dict[str, list[float]]] = None,
csv_filename: Optional[str] = None,
csv_overwrite: bool = False,
**add_kwargs,
Expand All @@ -125,6 +132,8 @@ def batch_add(
Defaults to 16.
embedding_fn (callable, optional): A function that computes embeddings for a given input string.
Default is 'get_openai_embedding', which uses the text-embedding-ada-002 model.
sparse_embedding_fn (callable, optional): A function that computes sparse embeddings for a given input string.
Default is None. Only use if you want sparse embeddings.
csv_filename (str, optional): Path to save a copy of the dataframe with computed embeddings for later use.
csv_overwrite (bool, optional): Whether to overwrite the file with a new file. Defaults to False.
When using batches, set to False to keep all embeddings in the same file. You may want to manually remove the file if experimenting.
Expand All @@ -151,6 +160,7 @@ def batch_add(
csv_filename=csv_filename,
csv_overwrite=csv_overwrite,
embedding_fn=embedding_fn,
sparse_embedding_fn=sparse_embedding_fn,
**add_kwargs,
)

Expand Down
24 changes: 23 additions & 1 deletion buster/documents_manager/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ def _add_documents(self, df: pd.DataFrame):
Args:
df: The dataframe containing the documents.
"""
use_sparse_vector = "sparse_embedding" in df.columns
hbertrand marked this conversation as resolved.
Show resolved Hide resolved
if use_sparse_vector:
logger.info("Uploading sparse embeddings too.")

for source in df.source.unique():
source_exists = self.db.sources.find_one({"name": source})
if source_exists is None:
Expand All @@ -75,14 +79,32 @@ def _add_documents(self, df: pd.DataFrame):
source_id = self.get_source_id(source)

df_source = df[df.source == source]
to_upsert = []
for row in df_source.to_dict(orient="records"):
embedding = row["embedding"].tolist()
if use_sparse_vector:
sparse_embedding = row["sparse_embedding"]

document = row.copy()
document.pop("embedding")
if use_sparse_vector:
document.pop("sparse_embedding")
document["source_id"] = source_id

document_id = str(self.db.documents.insert_one(document).inserted_id)
self.index.upsert([(document_id, embedding, {"source": source})], namespace=self.namespace)
vector = {"id": document_id, "values": embedding, "metadata": {"source": source}}
if use_sparse_vector:
vector["sparse_values"] = sparse_embedding

to_upsert.append(vector)

# Current (November 2023) Pinecone upload rules:
# - Max 1000 vectors per batch
# - Max 2 MB per batch
# Sparse vectors are heavier, so we reduce the batch size when using them.
MAX_PINECONE_BATCH_SIZE = 100 if use_sparse_vector else 1000
hbertrand marked this conversation as resolved.
Show resolved Hide resolved
for i in range(0, len(to_upsert), MAX_PINECONE_BATCH_SIZE):
self.index.upsert(vectors=to_upsert[i : i + MAX_PINECONE_BATCH_SIZE], namespace=self.namespace)

def update_source(self, source: str, display_name: str = None, note: str = None):
"""Update the display name and/or note of a source. Also create the source if it does not exist.
Expand Down
2 changes: 2 additions & 0 deletions buster/llm_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from buster.llm_utils.embeddings import (
BM25,
compute_embeddings_parallelized,
cosine_similarity,
get_openai_embedding,
Expand All @@ -12,4 +13,5 @@
get_openai_embedding,
compute_embeddings_parallelized,
get_openai_embedding_constructor,
BM25,
]
21 changes: 21 additions & 0 deletions buster/llm_utils/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import pandas as pd
from openai import OpenAI
from pinecone_text.sparse import BM25Encoder
from tqdm.contrib.concurrent import thread_map

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -64,3 +65,23 @@ def compute_embeddings_parallelized(df: pd.DataFrame, embedding_fn: callable, nu

logger.info(f"Finished computing embeddings")
return embeddings


class BM25:
    """Thin wrapper around a ``BM25Encoder`` used to build sparse embeddings.

    The wrapped encoder is kept on ``self.encoder``; every method below
    simply forwards to it.
    """

    def __init__(self, path_to_params: str = None) -> None:
        """Create the encoder and optionally restore saved parameters.

        Args:
            path_to_params: Path forwarded to the encoder's ``load`` method.
                When falsy (``None`` or an empty string) nothing is loaded.
        """
        encoder = BM25Encoder()
        if path_to_params:
            encoder.load(path_to_params)
        self.encoder = encoder

    def fit(self, df: pd.DataFrame):
        """Fit the encoder on the ``content`` column of ``df``.

        Args:
            df: Dataframe exposing a ``content`` column; its values are
                passed to the encoder as a plain list.
        """
        corpus = df.content.to_list()
        self.encoder.fit(corpus)

    def dump_params(self, path: str):
        """Forward ``path`` to the encoder's ``dump`` method.

        Args:
            path: Destination handed to the encoder unchanged.
        """
        self.encoder.dump(path)

    def get_sparse_embedding_fn(self):
        """Return a callable that encodes its argument with this encoder.

        The encoder is looked up on ``self`` at call time, so replacing
        ``self.encoder`` later also affects callables returned earlier.

        Returns:
            A one-argument function returning whatever the encoder's
            ``encode_queries`` returns for that argument.
        """

        def sparse_embedding_fn(query: str):
            # NOTE(review): this always uses the query-side encoding; confirm
            # that is intended when the callable is used to embed documents.
            encoded = self.encoder.encode_queries(query)
            return encoded

        return sparse_embedding_fn
25 changes: 11 additions & 14 deletions buster/retriever/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import lru_cache
from typing import Callable, Optional

import numpy as np
Expand All @@ -18,13 +17,22 @@

@dataclass
class Retriever(ABC):
def __init__(self, top_k: int, thresh: float, embedding_fn: Callable[[str], np.array] = None, *args, **kwargs):
def __init__(
self,
top_k: int,
thresh: float,
embedding_fn: Callable[[str], np.ndarray] = None,
sparse_embedding_fn: Callable[[str], dict[str, list[float]]] = None,
*args,
**kwargs,
):
"""Initializes a Retriever instance.

Args:
top_k: The maximum number of documents to retrieve.
thresh: The similarity threshold for document retrieval.
embedding_fn: The function to compute document embeddings.
sparse_embedding_fn: (Optional) The function to compute sparse document embeddings.
*args, **kwargs: Additional arguments and keyword arguments.
"""
if embedding_fn is None:
Expand All @@ -33,6 +41,7 @@ def __init__(self, top_k: int, thresh: float, embedding_fn: Callable[[str], np.a
self.top_k = top_k
self.thresh = thresh
self.embedding_fn = embedding_fn
self.sparse_embedding_fn = sparse_embedding_fn

# Add your access to documents in your own init

Expand Down Expand Up @@ -62,18 +71,6 @@ def get_source_display_name(self, source: str) -> str:
"""
...

def get_embedding(self, query: str) -> np.ndarray:
"""Generates the embedding of a query.

Args:
query: The query for which to generate the embedding.

Returns:
The embedding of the query as a NumPy array.
"""
logger.info("generating embedding")
return self.embedding_fn(query)

@abstractmethod
def get_topk_documents(self, query: str, source: Optional[str] = None, top_k: Optional[int] = None) -> pd.DataFrame:
"""Get the topk documents matching a user's query.
Expand Down
2 changes: 1 addition & 1 deletion buster/retriever/deeplake.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def get_topk_documents(
The DataFrame containing the matched documents.
"""
if query is not None:
query_embedding = self.get_embedding(query)
query_embedding = self.embedding_fn(query)
elif embedding is not None:
query_embedding = embedding
else:
Expand Down
10 changes: 8 additions & 2 deletions buster/retriever/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,21 @@ def get_topk_documents(self, query: str, sources: Optional[List[str]], top_k: in
logger.warning(f"Sources {sources} do not exist. Returning empty dataframe.")
return pd.DataFrame()

query_embedding = self.get_embedding(query)
query_embedding = self.embedding_fn(query)
sparse_query_embedding = self.sparse_embedding_fn(query) if self.sparse_embedding_fn is not None else None

if isinstance(query_embedding, np.ndarray):
# pinecone expects a list of floats, so convert from ndarray if necessary
query_embedding = query_embedding.tolist()

# Pinecone retrieval
matches = self.index.query(
query_embedding, top_k=top_k, filter=filter, include_values=True, namespace=self.namespace
vector=query_embedding,
sparse_vector=sparse_query_embedding,
top_k=top_k,
filter=filter,
include_values=True,
namespace=self.namespace,
)["matches"]
matching_ids = [ObjectId(match.id) for match in matches]
matching_scores = {match.id: match.score for match in matches}
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ numpy
openai>=1.0
pandas
pinecone-client
pinecone-text
pymongo
pytest
tabulate
Expand Down
10 changes: 5 additions & 5 deletions tests/test_chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from buster.documents_manager import DeepLakeDocumentsManager
from buster.formatters.documents import DocumentsFormatterHTML
from buster.formatters.prompts import PromptFormatter
from buster.llm_utils import get_openai_embedding
from buster.retriever import DeepLakeRetriever, Retriever
from buster.tokenizers.gpt import GPTTokenizer
from buster.validators import Validator
Expand Down Expand Up @@ -63,7 +64,7 @@
"top_k": 3,
"thresh": 0.7,
"max_tokens": 2000,
"embedding_model": "text-embedding-ada-002",
"embedding_fn": get_openai_embedding,
},
prompt_formatter_cfg={
"max_tokens": 3500,
Expand Down Expand Up @@ -130,6 +131,8 @@ def __init__(self, **kwargs):
}
)

self.embedding_fn = get_fake_embedding

def get_documents(self, source):
return self.documents

Expand All @@ -139,9 +142,6 @@ def get_topk_documents(self, query: str, sources: list[str] = None, top_k: int =
documents["similarity"] = [np.random.random() for _ in range(len(documents))]
return documents

def get_embedding(self, query, engine):
return get_fake_embedding()

def get_source_display_name(self, source):
return source

Expand Down Expand Up @@ -258,7 +258,7 @@ def test_chatbot_real_data__no_docs_found(vector_store_path):
buster_cfg = copy.deepcopy(buster_cfg_template)
buster_cfg.retriever_cfg = {
"path": vector_store_path,
"embedding_model": "text-embedding-ada-002",
"embedding_fn": get_openai_embedding,
"top_k": 3,
"thresh": 1, # Set threshold very high to be sure no docs are matched
"max_tokens": 3000,
Expand Down
6 changes: 3 additions & 3 deletions tests/test_documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_write_read(tmp_path, documents_manager, retriever):
"top_k": 3,
"thresh": 0.7,
"max_tokens": 2000,
"embedding_model": "text-embedding-ada-002",
"embedding_fn": get_openai_embedding,
}
dm_path = tmp_path / "tmp_dir_2"
retriever_cfg["path"] = dm_path
Expand Down Expand Up @@ -66,7 +66,7 @@ def test_write_write_read(tmp_path, documents_manager, retriever):
"top_k": 3,
"thresh": 0.7,
"max_tokens": 2000,
"embedding_model": "text-embedding-ada-002",
"embedding_fn": get_openai_embedding,
}
db_path = tmp_path / "tmp_dir"
retriever_cfg["path"] = db_path
Expand Down Expand Up @@ -123,7 +123,7 @@ def test_generate_embeddings(tmp_path, monkeypatch):
"top_k": 3,
"thresh": 0.85,
"max_tokens": 3000,
"embedding_model": "fake-embedding",
"embedding_fn": get_fake_embedding,
}
read_df = DeepLakeRetriever(**retriever_cfg).get_documents("my_source")

Expand Down