diff --git a/buster/busterbot.py b/buster/busterbot.py index 638c28c..88a066c 100644 --- a/buster/busterbot.py +++ b/buster/busterbot.py @@ -19,12 +19,8 @@ class BusterConfig: validator_cfg: dict = field( default_factory=lambda: { - "unknown_prompts": [ - "I Don't know how to answer your question.", - ], - "unknown_threshold": 0.85, - "embedding_model": "text-embedding-ada-002", "use_reranking": True, + "validate_documents": False, } ) tokenizer_cfg: dict = field( diff --git a/buster/validators/question_answer_validator.py b/buster/validators/question_answer_validator.py index 21b5118..de32a19 100644 --- a/buster/validators/question_answer_validator.py +++ b/buster/validators/question_answer_validator.py @@ -1,8 +1,8 @@ import concurrent.futures import logging from typing import Callable, List, Optional -import numpy as np +import numpy as np import pandas as pd from buster.completers import ChatGPTCompleter, Completer diff --git a/tests/test_chatbot.py b/tests/test_chatbot.py index e92fb97..a074b5b 100644 --- a/tests/test_chatbot.py +++ b/tests/test_chatbot.py @@ -14,7 +14,7 @@ from buster.formatters.prompts import PromptFormatter from buster.retriever import DeepLakeRetriever, Retriever from buster.tokenizers.gpt import GPTTokenizer -from buster.validators import QuestionAnswerValidator, Validator +from buster.validators import Validator logging.basicConfig(level=logging.INFO) @@ -32,14 +32,23 @@ }, }, validator_cfg={ - "unknown_response_templates": [ - UNKNOWN_PROMPT, - ], - "unknown_threshold": 0.85, - "embedding_model": "text-embedding-ada-002", + "validate_documents": False, "use_reranking": True, - "check_question_prompt": "You are validating if questions are related to AI. If a question is relevant, respond with 'true', if it is irrlevant, respond with 'false'.", - "completion_kwargs": {"temperature": 0, "model": "gpt-3.5-turbo"}, + "answer_validator_cfg": { + "unknown_response_templates": [ + "I'm sorry, but I am an AI language model trained to assist with questions related to AI. I cannot answer that question as it is not relevant to the library or its usage. Is there anything else I can assist you with?", + ], + "unknown_threshold": 0.85, + }, + "question_validator_cfg": { + "invalid_question_response": "This question does not seem relevant to my current knowledge.", + "completion_kwargs": { + "model": "gpt-3.5-turbo", + "stream": False, + "temperature": 0, + }, + "check_question_prompt": "You are validating if questions are related to AI. If a question is relevant, respond with 'true', if it is irrlevant, respond with 'false'.", + }, }, retriever_cfg={ # "db_path": to be set using pytest fixture, @@ -129,7 +138,7 @@ def get_source_display_name(self, source): return source -class MockValidator(Validator): +class MockValidator: def __init__(self, *args, **kwargs): return @@ -187,7 +196,7 @@ def test_chatbot_real_data__chatGPT(vector_store_path): documents_formatter=DocumentsFormatterHTML(tokenizer=tokenizer, **buster_cfg.documents_formatter_cfg), prompt_formatter=PromptFormatter(tokenizer=tokenizer, **buster_cfg.prompt_formatter_cfg), ) - validator: Validator = QuestionAnswerValidator(**buster_cfg.validator_cfg) + validator: Validator = Validator(**buster_cfg.validator_cfg) buster: Buster = Buster(retriever=retriever, document_answerer=document_answerer, validator=validator) completion = buster.process_input("What is backpropagation?") @@ -224,7 +233,7 @@ def test_chatbot_real_data__chatGPT_OOD(vector_store_path): documents_formatter=DocumentsFormatterHTML(tokenizer=tokenizer, **buster_cfg.documents_formatter_cfg), prompt_formatter=PromptFormatter(tokenizer=tokenizer, **buster_cfg.prompt_formatter_cfg), ) - validator: Validator = QuestionAnswerValidator(**buster_cfg.validator_cfg) + validator: Validator = Validator(**buster_cfg.validator_cfg) buster: Buster = Buster(retriever=retriever, document_answerer=document_answerer, validator=validator) completion: Completion = buster.process_input("What is a good recipe for brocolli soup?") @@ -255,7 +264,7 @@ def test_chatbot_real_data__no_docs_found(vector_store_path): prompt_formatter=PromptFormatter(tokenizer=tokenizer, **buster_cfg.prompt_formatter_cfg), **buster_cfg.documents_answerer_cfg, ) - validator: Validator = QuestionAnswerValidator(**buster_cfg.validator_cfg) + validator: Validator = Validator(**buster_cfg.validator_cfg) buster: Buster = Buster(retriever=retriever, document_answerer=document_answerer, validator=validator) completion = buster.process_input("What is backpropagation?") diff --git a/tests/test_validator.py b/tests/test_validator.py index 718d33d..ff1dc75 100644 --- a/tests/test_validator.py +++ b/tests/test_validator.py @@ -1,19 +1,28 @@ import pandas as pd from buster.llm_utils import get_openai_embedding -from buster.validators import QuestionAnswerValidator, Validator +from buster.validators import Validator validator_cfg = { - "unknown_response_templates": [ - "I Don't know how to answer your question.", - ], - "unknown_threshold": 0.85, - "embedding_model": "text-embedding-ada-002", "use_reranking": True, - "check_question_prompt": "You are validating if questions are related to AI. If a question is relevant, respond with 'true', if it is irrlevant, respond with 'false'.", - "completion_kwargs": {"temperature": 0, "model": "gpt-3.5-turbo"}, + "validate_documents": True, + "answer_validator_cfg": { + "unknown_response_templates": [ + "I Don't know how to answer your question.", + ], + "unknown_threshold": 0.85, + }, + "question_validator_cfg": { + "invalid_question_response": "This question does not seem relevant to my current knowledge.", + "completion_kwargs": { + "model": "gpt-3.5-turbo", + "stream": False, + "temperature": 0, + }, + "check_question_prompt": "You are validating if questions are related to AI. If a question is relevant, respond with 'true', if it is irrlevant, respond with 'false'.", + }, } -validator = QuestionAnswerValidator(**validator_cfg) +validator = Validator(**validator_cfg) def test_validator_check_question_relevance(): @@ -41,9 +50,7 @@ def test_validator_rerank_docs(): "A green apple on the counter", ] matched_documents = pd.DataFrame({"documents": documents}) - matched_documents["embedding"] = matched_documents.documents.apply( - lambda x: get_openai_embedding(x, model=validator.embedding_model) - ) + matched_documents["embedding"] = matched_documents.documents.apply(lambda x: get_openai_embedding(x)) answer = "An apple is a delicious fruit." reranked_documents = validator.rerank_docs(answer, matched_documents)