tests
hbertrand committed Nov 14, 2023
1 parent b491249 commit f3fd964
Showing 6 changed files with 11 additions and 17 deletions.
4 changes: 2 additions & 2 deletions buster/busterbot.py
@@ -5,7 +5,7 @@
 import pandas as pd
 
 from buster.completers import Completion, DocumentAnswerer, UserInputs
-from buster.llm_utils import QuestionReformulator
+from buster.llm_utils import QuestionReformulator, get_openai_embedding
 from buster.retriever import Retriever
 from buster.validators import Validator
 
@@ -37,7 +37,7 @@ class BusterConfig:
             "max_tokens": 3000,
             "top_k": 3,
             "thresh": 0.7,
-            "embedding_model": "text-embedding-ada-002",
+            "embedding_fn": get_openai_embedding,
         }
     )
     prompt_formatter_cfg: dict = field(
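The recurring change in this commit swaps the `embedding_model` string in `retriever_cfg` for an `embedding_fn` callable such as `get_openai_embedding`. As a rough illustration of what such a callable can look like, here is a minimal sketch assuming the OpenAI v1 Python client; the real `get_openai_embedding` in `buster.llm_utils` may differ (client reuse, retries, batching).

```python
from openai import OpenAI

_client = OpenAI()  # reads OPENAI_API_KEY from the environment


def get_openai_embedding_sketch(text: str, model: str = "text-embedding-ada-002") -> list[float]:
    """Hypothetical stand-in for buster.llm_utils.get_openai_embedding."""
    response = _client.embeddings.create(model=model, input=text)
    return response.data[0].embedding
```

With this shape, a retriever only needs to call `embedding_fn(query)`, which is why the `deeplake.py` hunk below can drop the `model=` argument.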
5 changes: 1 addition & 4 deletions buster/documents_manager/service.py
@@ -74,10 +74,7 @@ def _add_documents(self, df: pd.DataFrame):
 
             to_upsert.append(vector)
 
-        if use_sparse_vector:
-            MAX_PINECONE_BATCH_SIZE = 100
-        else:
-            MAX_PINECONE_BATCH_SIZE = 1000
+        MAX_PINECONE_BATCH_SIZE = 100 if use_sparse_vector else 1000
         for i in range(0, len(to_upsert), MAX_PINECONE_BATCH_SIZE):
             self.index.upsert(vectors=to_upsert[i : i + MAX_PINECONE_BATCH_SIZE], namespace=self.namespace)
 
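For context, the conditional batch size feeds a standard chunked-upsert loop. The generic sketch below is not buster code (`upsert_fn` is a hypothetical stand-in for `self.index.upsert`); it just shows the same pattern in isolation: sparse values inflate each payload, so smaller batches keep requests under Pinecone's size limits.

```python
from typing import Callable, Sequence


def batched_upsert(
    vectors: Sequence[dict],
    upsert_fn: Callable[[Sequence[dict]], None],
    use_sparse_vector: bool = False,
) -> None:
    # Hypothetical helper mirroring the loop in _add_documents:
    # sparse vectors make each item larger, so use a smaller batch.
    batch_size = 100 if use_sparse_vector else 1000
    for i in range(0, len(vectors), batch_size):
        upsert_fn(vectors[i : i + batch_size])
```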
2 changes: 1 addition & 1 deletion buster/retriever/deeplake.py
@@ -118,7 +118,7 @@ def get_topk_documents(
         If no matches are found, returns an empty dataframe."""
 
         if query is not None:
-            query_embedding = self.get_embedding(query, model=self.embedding_model)
+            query_embedding = self.get_embedding(query)
         elif embedding is not None:
             query_embedding = embedding
         else:
6 changes: 1 addition & 5 deletions buster/retriever/service.py
@@ -78,11 +78,7 @@ def get_topk_documents(self, query: str, sources: Optional[list[str]], top_k: int
             return pd.DataFrame()
 
         query_embedding = self.get_embedding(query)
-
-        if self.get_sparse_embedding is not None:
-            sparse_query_embedding = self.get_sparse_embedding(query)
-        else:
-            sparse_query_embedding = None
+        sparse_query_embedding = self.get_sparse_embedding(query) if self.get_sparse_embedding is not None else None
 
         if isinstance(query_embedding, np.ndarray):
             # pinecone expects a list of floats, so convert from ndarray if necessary
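The new one-liner leaves `sparse_query_embedding` as `None` whenever no sparse embedder is configured. For reference, a hybrid Pinecone query typically passes the dense vector plus an optional `sparse_vector` dict; the sketch below is a generic illustration of that pattern, not necessarily the exact call made later in `get_topk_documents`.

```python
def query_index_sketch(index, query_embedding, sparse_query_embedding=None, top_k=3, namespace=None):
    # Hypothetical illustration: only pass sparse_vector when one was computed.
    kwargs = dict(vector=query_embedding, top_k=top_k, include_metadata=True, namespace=namespace)
    if sparse_query_embedding is not None:
        # Pinecone expects sparse vectors as {"indices": [...], "values": [...]}.
        kwargs["sparse_vector"] = sparse_query_embedding
    return index.query(**kwargs)
```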
5 changes: 3 additions & 2 deletions tests/test_chatbot.py
@@ -12,6 +12,7 @@
 from buster.documents_manager import DeepLakeDocumentsManager
 from buster.formatters.documents import DocumentsFormatterHTML
 from buster.formatters.prompts import PromptFormatter
+from buster.llm_utils import get_openai_embedding
 from buster.retriever import DeepLakeRetriever, Retriever
 from buster.tokenizers.gpt import GPTTokenizer
 from buster.validators import QuestionAnswerValidator, Validator
@@ -46,7 +47,7 @@
         "top_k": 3,
         "thresh": 0.7,
         "max_tokens": 2000,
-        "embedding_model": "text-embedding-ada-002",
+        "embedding_fn": get_openai_embedding,
     },
     prompt_formatter_cfg={
         "max_tokens": 3500,
@@ -241,7 +242,7 @@ def test_chatbot_real_data__no_docs_found(vector_store_path):
     buster_cfg = copy.deepcopy(buster_cfg_template)
     buster_cfg.retriever_cfg = {
         "path": vector_store_path,
-        "embedding_model": "text-embedding-ada-002",
+        "embedding_fn": get_openai_embedding,
         "top_k": 3,
         "thresh": 1,  # Set threshold very high to be sure no docs are matched
         "max_tokens": 3000,
6 changes: 3 additions & 3 deletions tests/test_documents.py
@@ -27,7 +27,7 @@ def test_write_read(tmp_path, documents_manager, retriever):
         "top_k": 3,
         "thresh": 0.7,
         "max_tokens": 2000,
-        "embedding_model": "text-embedding-ada-002",
+        "embedding_fn": get_openai_embedding,
     }
     dm_path = tmp_path / "tmp_dir_2"
     retriever_cfg["path"] = dm_path
@@ -66,7 +66,7 @@ def test_write_write_read(tmp_path, documents_manager, retriever):
         "top_k": 3,
         "thresh": 0.7,
         "max_tokens": 2000,
-        "embedding_model": "text-embedding-ada-002",
+        "embedding_fn": get_openai_embedding,
     }
     db_path = tmp_path / "tmp_dir"
     retriever_cfg["path"] = db_path
@@ -123,7 +123,7 @@ def test_generate_embeddings(tmp_path, monkeypatch):
         "top_k": 3,
         "thresh": 0.85,
         "max_tokens": 3000,
-        "embedding_model": "fake-embedding",
+        "embedding_fn": get_fake_embedding,
     }
     read_df = DeepLakeRetriever(**retriever_cfg).get_documents("my_source")
 
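The last hunk swaps the `"fake-embedding"` placeholder for a `get_fake_embedding` callable, which keeps `test_generate_embeddings` free of API calls. A deterministic fake along these lines would do the job; this is an illustrative guess, not the actual helper from the buster test suite, and the 1536 dimension is only an assumption matching text-embedding-ada-002.

```python
import hashlib

import numpy as np


def get_fake_embedding_sketch(text: str, dim: int = 1536) -> list[float]:
    # Hypothetical fake embedder: deterministic per input, no network calls.
    seed = int.from_bytes(hashlib.sha256(text.encode()).digest()[:4], "big")
    return np.random.default_rng(seed).random(dim).tolist()
```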
