Skip to content

Commit

Permalink
Add sane defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
jerpint committed Nov 14, 2023
1 parent f3e0bb3 commit f1ea8da
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 14 deletions.
14 changes: 11 additions & 3 deletions buster/validators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,17 @@ def __init__(
answer_validator_cfg=None,
documents_validator_cfg=None,
):
self.question_validator = QuestionValidator(**question_validator_cfg)
self.answer_validator = AnswerValidator(**answer_validator_cfg)
self.documents_validator = DocumentsValidator(**documents_validator_cfg)
self.question_validator = (
QuestionValidator(**question_validator_cfg) if question_validator_cfg is not None else QuestionValidator()
)
self.answer_validator = (
AnswerValidator(**answer_validator_cfg) if answer_validator_cfg is not None else AnswerValidator()
)
self.documents_validator = (
DocumentsValidator(**documents_validator_cfg)
if documents_validator_cfg is not None
else DocumentsValidator()
)
self.use_reranking = use_reranking
self.validate_documents = validate_documents

Expand Down
72 changes: 61 additions & 11 deletions buster/validators/question_answer_validator.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import concurrent.futures
import logging
from typing import Callable, List
from typing import Callable, List, Optional
import numpy as np

import pandas as pd

from buster.completers import ChatGPTCompleter
from buster.completers import ChatGPTCompleter, Completer
from buster.llm_utils import cosine_similarity
from buster.llm_utils.embeddings import get_openai_embedding

Expand All @@ -14,8 +14,44 @@


class QuestionValidator:
def __init__(self, completion_kwargs: dict, check_question_prompt: str, invalid_question_response: str):
self.completer = ChatGPTCompleter(completion_kwargs=completion_kwargs)
def __init__(
self,
check_question_prompt: Optional[str] = None,
invalid_question_response: Optional[str] = None,
completion_kwargs: Optional[dict] = None,
completer: Optional[Completer] = None,
):
if check_question_prompt is None:
check_question_prompt = (
"""You are a chatbot answering questions on documentation.
Your job is to determine wether or not a question is valid, and should be answered.
More general questions are not considered valid, even if you might know the response.
A user will submit a question. Respond 'true' if it is valid, respond 'false' if it is invalid.
For example:
Q: What is backpropagation?
true
Q: What is the meaning of life?
false
A user will submit a question. Respond 'true' if it is valid, respond 'false' if it is invalid.""",
)

if completer is None:
completer = ChatGPTCompleter

if completion_kwargs is None:
completion_kwargs = (
{
"model": "gpt-3.5-turbo",
"stream": False,
"temperature": 0,
},
)

self.completer = completer(completion_kwargs=completion_kwargs)
self.check_question_prompt = check_question_prompt
self.invalid_question_response = invalid_question_response

Expand All @@ -38,12 +74,26 @@ def check_question_relevance(self, question: str) -> tuple[bool, str]:


class AnswerValidator:
def __init__(self, unknown_response_templates: list[str], unknown_threshold: float, embedding_fn: Callable[[str], np.array] = None):
self.unknown_response_templates = unknown_response_templates
self.unknown_threshold = unknown_threshold
def __init__(
self,
unknown_response_templates: Optional[list[str]] = None,
unknown_threshold: Optional[float] = None,
embedding_fn: Callable[[str], np.array] = None,
):
if unknown_threshold is None:
unknown_threshold = 0.85

if embedding_fn is None:
self.embedding_fn = get_openai_embedding
embedding_fn = get_openai_embedding

if unknown_response_templates is None:
unknown_response_templates = [
"I'm sorry, but I am an AI language model trained to assist with questions related to AI. I cannot answer that question as it is not relevant to the library or its usage. Is there anything else I can assist you with?",
]

self.embedding_fn = embedding_fn
self.unknown_response_templates = unknown_response_templates
self.unknown_threshold = unknown_threshold

def check_answer_relevance(self, answer: str) -> bool:
"""Check if a generated answer is relevant to the chatbot's knowledge."""
Expand All @@ -66,9 +116,9 @@ def check_answer_relevance(self, answer: str) -> bool:
class DocumentsValidator:
def __init__(
self,
completion_kwargs: dict = None,
system_prompt: str = None,
user_input_formatter: str = None,
completion_kwargs: Optional[dict] = None,
system_prompt: Optional[str] = None,
user_input_formatter: Optional[str] = None,
max_calls: int = 30,
):
if system_prompt is None:
Expand Down

0 comments on commit f1ea8da

Please sign in to comment.