hybrid retrieval (#152)
* hybrid retrieval
* tests
* openai retriever
* fix deeplake retriever
hbertrand authored Dec 7, 2023
1 parent 13ab209 commit db7db50
Showing 11 changed files with 91 additions and 32 deletions.
6 changes: 3 additions & 3 deletions buster/busterbot.py
@@ -1,11 +1,11 @@
 import logging
 from dataclasses import dataclass, field
-from typing import Any, Optional
+from typing import Optional
 
 import pandas as pd
 
 from buster.completers import Completion, DocumentAnswerer, UserInputs
-from buster.llm_utils import QuestionReformulator
+from buster.llm_utils import QuestionReformulator, get_openai_embedding
 from buster.retriever import Retriever
 from buster.validators import Validator

@@ -33,7 +33,7 @@ class BusterConfig:
"max_tokens": 3000,
"top_k": 3,
"thresh": 0.7,
"embedding_model": "text-embedding-ada-002",
"embedding_fn": get_openai_embedding,
}
)
prompt_formatter_cfg: dict = field(
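Note the config shape change above: retriever_cfg now carries the embedding callable itself rather than a model-name string. A minimal sketch of the resulting config, using only values that appear in this diff:

from buster.llm_utils import get_openai_embedding

# The former "embedding_model": "text-embedding-ada-002" entry is replaced by
# a callable that maps a string to its embedding.
retriever_cfg = {
    "max_tokens": 3000,
    "top_k": 3,
    "thresh": 0.7,
    "embedding_fn": get_openai_embedding,
}
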
16 changes: 13 additions & 3 deletions buster/documents_manager/base.py
@@ -2,8 +2,9 @@
 import time
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Optional
+from typing import Callable, Optional
 
+import numpy as np
 import pandas as pd
 from tqdm import tqdm

@@ -63,7 +64,8 @@ def add(
         self,
         df: pd.DataFrame,
         num_workers: int = 16,
-        embedding_fn: callable = get_openai_embedding,
+        embedding_fn: Callable[[str], np.ndarray] = get_openai_embedding,
+        sparse_embedding_fn: Callable[[str], dict[str, list[float]]] = None,
         csv_filename: Optional[str] = None,
         csv_overwrite: bool = True,
         **add_kwargs,
@@ -81,6 +83,8 @@ def add(
             num_workers (int, optional): The number of parallel workers to use for computing embeddings. Default is 32.
             embedding_fn (callable, optional): A function that computes embeddings for a given input string.
                 Default is 'get_embedding_openai' which uses the text-embedding-ada-002 model.
+            sparse_embedding_fn (callable, optional): A function that computes sparse embeddings for a given input string.
+                Default is None. Only use if you want sparse embeddings.
             csv_filename (str, optional): Path to save a copy of the dataframe with computed embeddings for later use.
             csv_overwrite (bool, optional): Whether to overwrite the file with a new file. Defaults to True.
             **add_kwargs: Additional keyword arguments to be passed to the '_add_documents' method.
@@ -92,6 +96,8 @@ def add(
         # Check if embeddings are present, computes them if not
         if "embedding" not in df.columns:
             df["embedding"] = compute_embeddings_parallelized(df, embedding_fn=embedding_fn, num_workers=num_workers)
+        if "sparse_embedding" not in df.columns and sparse_embedding_fn is not None:
+            df["sparse_embedding"] = sparse_embedding_fn(df.content.to_list())
 
         if csv_filename is not None:
             self._checkpoint_csv(df, csv_filename=csv_filename, csv_overwrite=csv_overwrite)
@@ -104,7 +110,8 @@ def batch_add(
         batch_size: int = 3000,
         min_time_interval: int = 60,
         num_workers: int = 16,
-        embedding_fn: callable = get_openai_embedding,
+        embedding_fn: Callable[[str], np.ndarray] = get_openai_embedding,
+        sparse_embedding_fn: Callable[[str], dict[str, list[float]]] = None,
         csv_filename: Optional[str] = None,
         csv_overwrite: bool = False,
         **add_kwargs,
@@ -125,6 +132,8 @@ def batch_add(
                 Defaults to 32.
             embedding_fn (callable, optional): A function that computes embeddings for a given input string.
                 Default is 'get_embedding_openai' which uses the text-embedding-ada-002 model.
+            sparse_embedding_fn (callable, optional): A function that computes sparse embeddings for a given input string.
+                Default is None. Only use if you want sparse embeddings.
             csv_filename (str, optional): Path to save a copy of the dataframe with computed embeddings for later use.
             csv_overwrite (bool, optional): Whether to overwrite the file with a new file. Defaults to False.
                 When using batches, set to False to keep all embeddings in the same file. You may want to manually remove the file if experimenting.
@@ -151,6 +160,7 @@ def batch_add(
                 csv_filename=csv_filename,
                 csv_overwrite=csv_overwrite,
                 embedding_fn=embedding_fn,
+                sparse_embedding_fn=sparse_embedding_fn,
                 **add_kwargs,
             )

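The new sparse_embedding_fn hook is opt-in: add() only computes a "sparse_embedding" column when the hook is provided, and it passes the full list of contents in a single call (pinecone-text's BM25 encoder accepts a single string or a list). A minimal wiring sketch, assuming the DeepLakeDocumentsManager constructor used in the tests (the path and dataframe here are hypothetical, and dense embeddings still require OpenAI credentials):

import pandas as pd

from buster.documents_manager import DeepLakeDocumentsManager
from buster.llm_utils import BM25

df = pd.DataFrame({"content": ["first doc", "second doc"], "source": ["docs", "docs"]})

# Fit the sparse encoder on the corpus before encoding it.
bm25 = BM25()
bm25.fit(df)

manager = DeepLakeDocumentsManager(vector_store_path="deeplake_store")  # hypothetical path
manager.add(
    df,  # dense embeddings are computed with the default get_openai_embedding
    sparse_embedding_fn=bm25.get_sparse_embedding_fn(),  # omit for dense-only behavior
)

Per the service.py changes below, only the Pinecone-backed manager is shown actually uploading the sparse column; other managers simply receive it in the dataframe.
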
24 changes: 23 additions & 1 deletion buster/documents_manager/service.py
@@ -67,6 +67,10 @@ def _add_documents(self, df: pd.DataFrame):
         Args:
             df: The dataframe containing the documents.
         """
+        use_sparse_vector = "sparse_embedding" in df.columns
+        if use_sparse_vector:
+            logger.info("Uploading sparse embeddings too.")
+
         for source in df.source.unique():
             source_exists = self.db.sources.find_one({"name": source})
             if source_exists is None:
@@ -75,14 +79,32 @@ def _add_documents(self, df: pd.DataFrame):
         source_id = self.get_source_id(source)
 
         df_source = df[df.source == source]
+        to_upsert = []
         for row in df_source.to_dict(orient="records"):
             embedding = row["embedding"].tolist()
+            if use_sparse_vector:
+                sparse_embedding = row["sparse_embedding"]
 
             document = row.copy()
             document.pop("embedding")
+            if use_sparse_vector:
+                document.pop("sparse_embedding")
             document["source_id"] = source_id
 
             document_id = str(self.db.documents.insert_one(document).inserted_id)
-            self.index.upsert([(document_id, embedding, {"source": source})], namespace=self.namespace)
+            vector = {"id": document_id, "values": embedding, "metadata": {"source": source}}
+            if use_sparse_vector:
+                vector["sparse_values"] = sparse_embedding
+
+            to_upsert.append(vector)
+
+        # Current (November 2023) Pinecone upload rules:
+        # - Max 1000 vectors per batch
+        # - Max 2 MB per batch
+        # Sparse vectors are heavier, so we reduce the batch size when using them.
+        MAX_PINECONE_BATCH_SIZE = 100 if use_sparse_vector else 1000
+        for i in range(0, len(to_upsert), MAX_PINECONE_BATCH_SIZE):
+            self.index.upsert(vectors=to_upsert[i : i + MAX_PINECONE_BATCH_SIZE], namespace=self.namespace)
 
     def update_source(self, source: str, display_name: str = None, note: str = None):
         """Update the display name and/or note of a source. Also create the source if it does not exist.
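For reference, the sparse_values payload upserted above follows pinecone-text's sparse vector format, a dict of parallel index/weight lists (the numbers below are illustrative only):

# Illustrative shape of one sparse embedding as produced by the BM25 encoder:
sparse_embedding = {
    "indices": [12, 102, 2048],    # positions of the non-zero dimensions
    "values": [0.42, 0.31, 0.05],  # their BM25 weights
}
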
2 changes: 2 additions & 0 deletions buster/llm_utils/__init__.py
@@ -1,4 +1,5 @@
 from buster.llm_utils.embeddings import (
+    BM25,
     compute_embeddings_parallelized,
     cosine_similarity,
     get_openai_embedding,
@@ -12,4 +13,5 @@
     get_openai_embedding,
     compute_embeddings_parallelized,
     get_openai_embedding_constructor,
+    BM25,
 ]
21 changes: 21 additions & 0 deletions buster/llm_utils/embeddings.py
@@ -5,6 +5,7 @@
 import numpy as np
 import pandas as pd
 from openai import OpenAI
+from pinecone_text.sparse import BM25Encoder
 from tqdm.contrib.concurrent import thread_map
 
 logger = logging.getLogger(__name__)
@@ -64,3 +65,23 @@ def compute_embeddings_parallelized(df: pd.DataFrame, embedding_fn: callable, nu

logger.info(f"Finished computing embeddings")
return embeddings


class BM25:
def __init__(self, path_to_params: str = None) -> None:
self.encoder = BM25Encoder()

if path_to_params:
self.encoder.load(path_to_params)

def fit(self, df: pd.DataFrame):
self.encoder.fit(df.content.to_list())

def dump_params(self, path: str):
self.encoder.dump(path)

def get_sparse_embedding_fn(self):
def sparse_embedding_fn(query: str):
return self.encoder.encode_queries(query)

return sparse_embedding_fn
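
A minimal usage sketch of the BM25 helper added above (the dataframe and file path are hypothetical):

import pandas as pd

from buster.llm_utils import BM25

df = pd.DataFrame({"content": ["first document", "second document"]})

bm25 = BM25()
bm25.fit(df)                          # fit the underlying BM25Encoder on the corpus
bm25.dump_params("bm25_params.json")  # persist the fitted parameters

# Later, e.g. at retrieval time: reload the fitted parameters and encode queries.
bm25 = BM25(path_to_params="bm25_params.json")
sparse_embedding_fn = bm25.get_sparse_embedding_fn()
sparse_query = sparse_embedding_fn("second document")  # dict of indices/values
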
25 changes: 11 additions & 14 deletions buster/retriever/base.py
@@ -1,7 +1,6 @@
 import logging
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from functools import lru_cache
 from typing import Callable, Optional
 
 import numpy as np
@@ -18,13 +17,22 @@

 @dataclass
 class Retriever(ABC):
-    def __init__(self, top_k: int, thresh: float, embedding_fn: Callable[[str], np.array] = None, *args, **kwargs):
+    def __init__(
+        self,
+        top_k: int,
+        thresh: float,
+        embedding_fn: Callable[[str], np.ndarray] = None,
+        sparse_embedding_fn: Callable[[str], dict[str, list[float]]] = None,
+        *args,
+        **kwargs,
+    ):
         """Initializes a Retriever instance.
 
         Args:
             top_k: The maximum number of documents to retrieve.
             thresh: The similarity threshold for document retrieval.
             embedding_fn: The function to compute document embeddings.
+            sparse_embedding_fn: (Optional) The function to compute sparse document embeddings.
             *args, **kwargs: Additional arguments and keyword arguments.
         """
         if embedding_fn is None:
@@ -33,6 +41,7 @@ def __init__(self, top_k: int, thresh: float, embedding_fn: Callable[[str], np.a
         self.top_k = top_k
         self.thresh = thresh
         self.embedding_fn = embedding_fn
+        self.sparse_embedding_fn = sparse_embedding_fn
 
         # Add your access to documents in your own init
 
@@ -62,18 +71,6 @@ def get_source_display_name(self, source: str) -> str:
"""
...

def get_embedding(self, query: str) -> np.ndarray:
"""Generates the embedding of a query.
Args:
query: The query for which to generate the embedding.
Returns:
The embedding of the query as a NumPy array.
"""
logger.info("generating embedding")
return self.embedding_fn(query)

@abstractmethod
def get_topk_documents(self, query: str, source: Optional[str] = None, top_k: Optional[int] = None) -> pd.DataFrame:
"""Get the topk documents matching a user's query.
2 changes: 1 addition & 1 deletion buster/retriever/deeplake.py
@@ -169,7 +169,7 @@ def get_topk_documents(
             The DataFrame containing the matched documents.
         """
         if query is not None:
-            query_embedding = self.get_embedding(query)
+            query_embedding = self.embedding_fn(query)
         elif embedding is not None:
             query_embedding = embedding
         else:
10 changes: 8 additions & 2 deletions buster/retriever/service.py
@@ -121,15 +121,21 @@ def get_topk_documents(self, query: str, sources: Optional[List[str]], top_k: in
             logger.warning(f"Sources {sources} do not exist. Returning empty dataframe.")
             return pd.DataFrame()
 
-        query_embedding = self.get_embedding(query)
+        query_embedding = self.embedding_fn(query)
+        sparse_query_embedding = self.sparse_embedding_fn(query) if self.sparse_embedding_fn is not None else None
 
         if isinstance(query_embedding, np.ndarray):
             # pinecone expects a list of floats, so convert from ndarray if necessary
             query_embedding = query_embedding.tolist()
 
         # Pinecone retrieval
         matches = self.index.query(
-            query_embedding, top_k=top_k, filter=filter, include_values=True, namespace=self.namespace
+            vector=query_embedding,
+            sparse_vector=sparse_query_embedding,
+            top_k=top_k,
+            filter=filter,
+            include_values=True,
+            namespace=self.namespace,
         )["matches"]
         matching_ids = [ObjectId(match.id) for match in matches]
         matching_scores = {match.id: match.score for match in matches}
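Putting the retriever side together: supplying both embedding functions enables the hybrid Pinecone query above, while leaving sparse_embedding_fn unset keeps retrieval dense-only (sparse_query_embedding stays None). A hedged sketch of the config; the Pinecone/MongoDB connection settings are omitted because this diff does not show them:

from buster.llm_utils import BM25, get_openai_embedding

bm25 = BM25(path_to_params="bm25_params.json")  # hypothetical pre-fitted parameters

retriever_cfg = {
    "top_k": 3,
    "thresh": 0.7,
    "embedding_fn": get_openai_embedding,
    "sparse_embedding_fn": bm25.get_sparse_embedding_fn(),
    # ...plus the Pinecone/MongoDB connection settings not shown in this diff
}
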
1 change: 1 addition & 0 deletions requirements.txt
@@ -7,6 +7,7 @@ numpy
 openai>=1.0
 pandas
 pinecone-client
+pinecone-text
 pymongo
 pytest
 tabulate
10 changes: 5 additions & 5 deletions tests/test_chatbot.py
@@ -12,6 +12,7 @@
 from buster.documents_manager import DeepLakeDocumentsManager
 from buster.formatters.documents import DocumentsFormatterHTML
 from buster.formatters.prompts import PromptFormatter
+from buster.llm_utils import get_openai_embedding
 from buster.retriever import DeepLakeRetriever, Retriever
 from buster.tokenizers.gpt import GPTTokenizer
 from buster.validators import Validator
@@ -63,7 +64,7 @@
"top_k": 3,
"thresh": 0.7,
"max_tokens": 2000,
"embedding_model": "text-embedding-ada-002",
"embedding_fn": get_openai_embedding,
},
prompt_formatter_cfg={
"max_tokens": 3500,
@@ -130,6 +131,8 @@ def __init__(self, **kwargs):
             }
         )
 
+        self.embedding_fn = get_fake_embedding
+
     def get_documents(self, source):
         return self.documents
 
@@ -139,9 +142,6 @@ def get_topk_documents(self, query: str, sources: list[str] = None, top_k: int =
documents["similarity"] = [np.random.random() for _ in range(len(documents))]
return documents

def get_embedding(self, query, engine):
return get_fake_embedding()

def get_source_display_name(self, source):
return source

@@ -258,7 +258,7 @@ def test_chatbot_real_data__no_docs_found(vector_store_path):
     buster_cfg = copy.deepcopy(buster_cfg_template)
     buster_cfg.retriever_cfg = {
         "path": vector_store_path,
-        "embedding_model": "text-embedding-ada-002",
+        "embedding_fn": get_openai_embedding,
         "top_k": 3,
         "thresh": 1,  # Set threshold very high to be sure no docs are matched
         "max_tokens": 3000,
6 changes: 3 additions & 3 deletions tests/test_documents.py
@@ -27,7 +27,7 @@ def test_write_read(tmp_path, documents_manager, retriever):
"top_k": 3,
"thresh": 0.7,
"max_tokens": 2000,
"embedding_model": "text-embedding-ada-002",
"embedding_fn": get_openai_embedding,
}
dm_path = tmp_path / "tmp_dir_2"
retriever_cfg["path"] = dm_path
@@ -66,7 +66,7 @@ def test_write_write_read(tmp_path, documents_manager, retriever):
"top_k": 3,
"thresh": 0.7,
"max_tokens": 2000,
"embedding_model": "text-embedding-ada-002",
"embedding_fn": get_openai_embedding,
}
db_path = tmp_path / "tmp_dir"
retriever_cfg["path"] = db_path
@@ -123,7 +123,7 @@ def test_generate_embeddings(tmp_path, monkeypatch):
"top_k": 3,
"thresh": 0.85,
"max_tokens": 3000,
"embedding_model": "fake-embedding",
"embedding_fn": get_fake_embedding,
}
read_df = DeepLakeRetriever(**retriever_cfg).get_documents("my_source")

