Skip to content

Commit

Permalink
fix tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
jerpint committed Nov 14, 2023
1 parent f1ea8da commit 570c472
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 30 deletions.
6 changes: 1 addition & 5 deletions buster/busterbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,8 @@ class BusterConfig:

validator_cfg: dict = field(
default_factory=lambda: {
"unknown_prompts": [
"I Don't know how to answer your question.",
],
"unknown_threshold": 0.85,
"embedding_model": "text-embedding-ada-002",
"use_reranking": True,
"validate_documents": False,
}
)
tokenizer_cfg: dict = field(
Expand Down
2 changes: 1 addition & 1 deletion buster/validators/question_answer_validator.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import concurrent.futures
import logging
from typing import Callable, List, Optional
import numpy as np

import numpy as np
import pandas as pd

from buster.completers import ChatGPTCompleter, Completer
Expand Down
33 changes: 21 additions & 12 deletions tests/test_chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from buster.formatters.prompts import PromptFormatter
from buster.retriever import DeepLakeRetriever, Retriever
from buster.tokenizers.gpt import GPTTokenizer
from buster.validators import QuestionAnswerValidator, Validator
from buster.validators import Validator

logging.basicConfig(level=logging.INFO)

Expand All @@ -32,14 +32,23 @@
},
},
validator_cfg={
"unknown_response_templates": [
UNKNOWN_PROMPT,
],
"unknown_threshold": 0.85,
"embedding_model": "text-embedding-ada-002",
"validate_documents": False,
"use_reranking": True,
"check_question_prompt": "You are validating if questions are related to AI. If a question is relevant, respond with 'true', if it is irrlevant, respond with 'false'.",
"completion_kwargs": {"temperature": 0, "model": "gpt-3.5-turbo"},
"answer_validator_cfg": {
"unknown_response_templates": [
"I'm sorry, but I am an AI language model trained to assist with questions related to AI. I cannot answer that question as it is not relevant to the library or its usage. Is there anything else I can assist you with?",
],
"unknown_threshold": 0.85,
},
"question_validator_cfg": {
"invalid_question_response": "This question does not seem relevant to my current knowledge.",
"completion_kwargs": {
"model": "gpt-3.5-turbo",
"stream": False,
"temperature": 0,
},
"check_question_prompt": "You are validating if questions are related to AI. If a question is relevant, respond with 'true', if it is irrlevant, respond with 'false'.",
},
},
retriever_cfg={
# "db_path": to be set using pytest fixture,
Expand Down Expand Up @@ -129,7 +138,7 @@ def get_source_display_name(self, source):
return source


class MockValidator(Validator):
class MockValidator:
def __init__(self, *args, **kwargs):
return

Expand Down Expand Up @@ -187,7 +196,7 @@ def test_chatbot_real_data__chatGPT(vector_store_path):
documents_formatter=DocumentsFormatterHTML(tokenizer=tokenizer, **buster_cfg.documents_formatter_cfg),
prompt_formatter=PromptFormatter(tokenizer=tokenizer, **buster_cfg.prompt_formatter_cfg),
)
validator: Validator = QuestionAnswerValidator(**buster_cfg.validator_cfg)
validator: Validator = Validator(**buster_cfg.validator_cfg)
buster: Buster = Buster(retriever=retriever, document_answerer=document_answerer, validator=validator)

completion = buster.process_input("What is backpropagation?")
Expand Down Expand Up @@ -224,7 +233,7 @@ def test_chatbot_real_data__chatGPT_OOD(vector_store_path):
documents_formatter=DocumentsFormatterHTML(tokenizer=tokenizer, **buster_cfg.documents_formatter_cfg),
prompt_formatter=PromptFormatter(tokenizer=tokenizer, **buster_cfg.prompt_formatter_cfg),
)
validator: Validator = QuestionAnswerValidator(**buster_cfg.validator_cfg)
validator: Validator = Validator(**buster_cfg.validator_cfg)
buster: Buster = Buster(retriever=retriever, document_answerer=document_answerer, validator=validator)

completion: Completion = buster.process_input("What is a good recipe for brocolli soup?")
Expand Down Expand Up @@ -255,7 +264,7 @@ def test_chatbot_real_data__no_docs_found(vector_store_path):
prompt_formatter=PromptFormatter(tokenizer=tokenizer, **buster_cfg.prompt_formatter_cfg),
**buster_cfg.documents_answerer_cfg,
)
validator: Validator = QuestionAnswerValidator(**buster_cfg.validator_cfg)
validator: Validator = Validator(**buster_cfg.validator_cfg)
buster: Buster = Buster(retriever=retriever, document_answerer=document_answerer, validator=validator)

completion = buster.process_input("What is backpropagation?")
Expand Down
31 changes: 19 additions & 12 deletions tests/test_validator.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,28 @@
import pandas as pd

from buster.llm_utils import get_openai_embedding
from buster.validators import QuestionAnswerValidator, Validator
from buster.validators import Validator

validator_cfg = {
"unknown_response_templates": [
"I Don't know how to answer your question.",
],
"unknown_threshold": 0.85,
"embedding_model": "text-embedding-ada-002",
"use_reranking": True,
"check_question_prompt": "You are validating if questions are related to AI. If a question is relevant, respond with 'true', if it is irrlevant, respond with 'false'.",
"completion_kwargs": {"temperature": 0, "model": "gpt-3.5-turbo"},
"validate_documents": True,
"answer_validator_cfg": {
"unknown_response_templates": [
"I Don't know how to answer your question.",
],
"unknown_threshold": 0.85,
},
"question_validator_cfg": {
"invalid_question_response": "This question does not seem relevant to my current knowledge.",
"completion_kwargs": {
"model": "gpt-3.5-turbo",
"stream": False,
"temperature": 0,
},
"check_question_prompt": "You are validating if questions are related to AI. If a question is relevant, respond with 'true', if it is irrlevant, respond with 'false'.",
},
}
validator = QuestionAnswerValidator(**validator_cfg)
validator = Validator(**validator_cfg)


def test_validator_check_question_relevance():
Expand Down Expand Up @@ -41,9 +50,7 @@ def test_validator_rerank_docs():
"A green apple on the counter",
]
matched_documents = pd.DataFrame({"documents": documents})
matched_documents["embedding"] = matched_documents.documents.apply(
lambda x: get_openai_embedding(x, model=validator.embedding_model)
)
matched_documents["embedding"] = matched_documents.documents.apply(lambda x: get_openai_embedding(x))

answer = "An apple is a delicious fruit."
reranked_documents = validator.rerank_docs(answer, matched_documents)
Expand Down

0 comments on commit 570c472

Please sign in to comment.