Add documents evaluator, refactor validators (#151)
* add documents validator

* refactor validators

* Add sane defaults

* fix tests
jerpint authored Nov 14, 2023
1 parent 008e2ba commit 64ba998
Showing 10 changed files with 311 additions and 162 deletions.
6 changes: 1 addition & 5 deletions buster/busterbot.py
@@ -19,12 +19,8 @@ class BusterConfig:

validator_cfg: dict = field(
default_factory=lambda: {
"unknown_prompts": [
"I Don't know how to answer your question.",
],
"unknown_threshold": 0.85,
"embedding_model": "text-embedding-ada-002",
"use_reranking": True,
"validate_documents": False,
}
)
tokenizer_cfg: dict = field(
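
With this refactor, the default validator_cfg in BusterConfig carries only the two top-level flags; the per-validator settings (thresholds, prompts, embedding model) move into optional sub-configs, as the cfg.py example further down shows. A minimal sketch of what the new defaults imply, assuming the sub-validators can be constructed without arguments (per the "Add sane defaults" commit message):

from buster.validators import Validator

# Keys taken from the new BusterConfig default above.
validator_cfg = {
    "use_reranking": True,        # re-rank matched documents against the generated answer
    "validate_documents": False,  # skip the new per-document LLM relevance check by default
}
validator = Validator(**validator_cfg)  # sub-validators fall back to their own defaults
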
5 changes: 5 additions & 0 deletions buster/completers/base.py
@@ -137,6 +137,11 @@ def postprocess(self):
answer=self.answer_text, matched_documents=self.matched_documents
)

if self.validator.validate_documents:
self.matched_documents = self.validator.check_documents_relevance(
answer=self.answer_text, matched_documents=self.matched_documents
)

# access the property so it gets set if not computed already
self.answer_relevant

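
The completer's postprocess() now gates an extra step on the new validate_documents flag: after reranking, each matched document can be checked for relevance against the generated answer. A simplified sketch of the resulting flow; the reranking gate is assumed from the partial context visible above, not quoted from the file:

def postprocess(self):
    # assumed from context: rerank matched documents against the answer when enabled
    if self.validator.use_reranking:
        self.matched_documents = self.validator.rerank_docs(
            answer=self.answer_text, matched_documents=self.matched_documents
        )

    # new in this commit: optionally ask an LLM whether each matched document
    # is actually relevant to the generated answer
    if self.validator.validate_documents:
        self.matched_documents = self.validator.check_documents_relevance(
            answer=self.answer_text, matched_documents=self.matched_documents
        )

    # access the property so it gets computed if it hasn't been already
    self.answer_relevant
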
40 changes: 25 additions & 15 deletions buster/examples/cfg.py
@@ -4,19 +4,18 @@
from buster.formatters.prompts import PromptFormatter
from buster.retriever import DeepLakeRetriever, Retriever
from buster.tokenizers import GPTTokenizer
from buster.validators import QuestionAnswerValidator, Validator
from buster.validators import Validator

buster_cfg = BusterConfig(
validator_cfg={
"unknown_response_templates": [
"I'm sorry, but I am an AI language model trained to assist with questions related to AI. I cannot answer that question as it is not relevant to the library or its usage. Is there anything else I can assist you with?",
],
"unknown_threshold": 0.85,
"embedding_model": "text-embedding-ada-002",
"use_reranking": True,
"invalid_question_response": "This question does not seem relevant to my current knowledge.",
"check_question_prompt": """You are an chatbot answering questions on artificial intelligence.
"question_validator_cfg": {
"invalid_question_response": "This question does not seem relevant to my current knowledge.",
"completion_kwargs": {
"model": "gpt-3.5-turbo",
"stream": False,
"temperature": 0,
},
"check_question_prompt": """You are a chatbot answering questions on artificial intelligence.
Your job is to determine whether or not a question is valid, and should be answered.
More general questions are not considered valid, even if you might know the response.
A user will submit a question. Respond 'true' if it is valid, respond 'false' if it is invalid.
@@ -30,11 +29,22 @@
false
A user will submit a question. Respond 'true' if it is valid, respond 'false' if it is invalid.""",
"completion_kwargs": {
"model": "gpt-3.5-turbo",
"stream": False,
"temperature": 0,
},
"answer_validator_cfg": {
"unknown_response_templates": [
"I'm sorry, but I am an AI language model trained to assist with questions related to AI. I cannot answer that question as it is not relevant to the library or its usage. Is there anything else I can assist you with?",
],
"unknown_threshold": 0.85,
},
"documents_validator_cfg": {
"completion_kwargs": {
"model": "gpt-3.5-turbo",
"stream": False,
"temperature": 0,
},
},
"use_reranking": True,
"validate_documents": True,
},
retriever_cfg={
"path": "deeplake_store",
@@ -98,6 +108,6 @@ def setup_buster(buster_cfg: BusterConfig):
prompt_formatter=PromptFormatter(tokenizer=tokenizer, **buster_cfg.prompt_formatter_cfg),
**buster_cfg.documents_answerer_cfg,
)
validator: Validator = QuestionAnswerValidator(**buster_cfg.validator_cfg)
validator: Validator = Validator(**buster_cfg.validator_cfg)
buster: Buster = Buster(retriever=retriever, document_answerer=document_answerer, validator=validator)
return buster
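
The example config now mirrors the refactored Validator: the flat keys are replaced by one optional sub-config per validator plus the two top-level flags, and setup_buster builds a plain Validator instead of QuestionAnswerValidator. A trimmed sketch of how the nesting maps onto the constructor (string values are placeholders, not recommendations):

from buster.validators import Validator

validator = Validator(
    use_reranking=True,
    validate_documents=True,
    question_validator_cfg={
        "invalid_question_response": "This question does not seem relevant to my current knowledge.",
        "check_question_prompt": "Respond 'true' if the question is valid, 'false' otherwise.",
        "completion_kwargs": {"model": "gpt-3.5-turbo", "stream": False, "temperature": 0},
    },
    answer_validator_cfg={
        "unknown_response_templates": ["I cannot answer that question."],
        "unknown_threshold": 0.85,
    },
    documents_validator_cfg={
        "completion_kwargs": {"model": "gpt-3.5-turbo", "stream": False, "temperature": 0},
    },
)

Any sub-config left out falls back to the defaults added in this commit.
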
4 changes: 3 additions & 1 deletion buster/llm_utils/embeddings.py
@@ -1,4 +1,5 @@
import logging
from functools import lru_cache

import numpy as np
import pandas as pd
@@ -11,7 +12,8 @@
client = OpenAI()


def get_openai_embedding(text: str, model: str = "text-embedding-ada-002"):
@lru_cache
def get_openai_embedding(text: str, model: str = "text-embedding-ada-002") -> np.array:
try:
text = text.replace("\n", " ")
response = client.embeddings.create(
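
get_openai_embedding is now wrapped in functools.lru_cache, so repeated embedding requests for the same (text, model) pair within a process are served from memory instead of hitting the API; this also makes the cached Validator.get_embedding static method removed below redundant. A generic illustration of the caching pattern, not buster's code:

from functools import lru_cache

@lru_cache  # bare decorator form requires Python 3.8+
def embed(text: str, model: str = "text-embedding-ada-002") -> tuple:
    print(f"computing embedding for {text!r}")  # only runs on a cache miss
    return (0.0, 1.0)  # placeholder vector; the real function calls the OpenAI API

embed("hello")             # miss: computes and caches
embed("hello")             # hit: no recomputation, no API call
print(embed.cache_info())  # CacheInfo(hits=1, misses=1, maxsize=128, currsize=1)

One consequence is that the cached return object is shared between calls, so callers should avoid mutating the returned vector in place.
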
3 changes: 1 addition & 2 deletions buster/validators/__init__.py
@@ -1,4 +1,3 @@
from .base import Validator
from .question_answer_validator import QuestionAnswerValidator

__all__ = [Validator, QuestionAnswerValidator]
__all__ = [Validator]
57 changes: 32 additions & 25 deletions buster/validators/base.py
@@ -1,44 +1,53 @@
import logging
from abc import ABC, abstractmethod
from functools import lru_cache

import pandas as pd

from buster.llm_utils import cosine_similarity, get_openai_embedding
from buster.validators.validators import (
AnswerValidator,
DocumentsValidator,
QuestionValidator,
)

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class Validator(ABC):
class Validator:
def __init__(
self,
embedding_model: str,
unknown_threshold: float,
use_reranking: bool,
invalid_question_response: str = "This question is not relevant to my internal knowledge base.",
validate_documents: bool,
question_validator_cfg=None,
answer_validator_cfg=None,
documents_validator_cfg=None,
):
self.embedding_model = embedding_model
self.unknown_threshold = unknown_threshold
self.question_validator = (
QuestionValidator(**question_validator_cfg) if question_validator_cfg is not None else QuestionValidator()
)
self.answer_validator = (
AnswerValidator(**answer_validator_cfg) if answer_validator_cfg is not None else AnswerValidator()
)
self.documents_validator = (
DocumentsValidator(**documents_validator_cfg)
if documents_validator_cfg is not None
else DocumentsValidator()
)
self.use_reranking = use_reranking
self.invalid_question_response = invalid_question_response

@staticmethod
@lru_cache
def get_embedding(text: str, model: str):
"""Currently supports OpenAI embeddings, override to add your own."""
logger.info("generating embedding")
return get_openai_embedding(text, model)
self.validate_documents = validate_documents

@abstractmethod
def check_question_relevance(self, question: str) -> tuple[bool, str]:
...
return self.question_validator.check_question_relevance(question)

@abstractmethod
def check_answer_relevance(self, answer: str) -> bool:
...
return self.answer_validator.check_answer_relevance(answer)

def rerank_docs(self, answer: str, matched_documents: pd.DataFrame) -> pd.DataFrame:
def check_documents_relevance(self, answer: str, matched_documents: pd.DataFrame) -> pd.DataFrame:
return self.documents_validator.check_documents_relevance(answer, matched_documents)

def rerank_docs(
self, answer: str, matched_documents: pd.DataFrame, embedding_fn=get_openai_embedding
) -> pd.DataFrame:
"""Here we re-rank matched documents according to the answer provided by the llm.
This score could be used to determine whether a document was actually relevant to generation.
@@ -48,10 +57,8 @@ def rerank_docs(self, answer: str, matched_documents: pd.DataFrame) -> pd.DataFrame:
return matched_documents
logger.info("Reranking documents based on answer similarity...")

answer_embedding = self.get_embedding(
answer,
model=self.embedding_model,
)
answer_embedding = embedding_fn(answer)

col = "similarity_to_answer"
matched_documents[col] = matched_documents.embedding.apply(lambda x: cosine_similarity(x, answer_embedding))

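
The base Validator is no longer an abstract class to subclass: it now composes a QuestionValidator, AnswerValidator, and DocumentsValidator and delegates to them, and rerank_docs accepts an injectable embedding_fn in place of the old cached get_embedding static method. A usage sketch under those assumptions; the custom embedding function is hypothetical:

import numpy as np

from buster.validators import Validator

def my_embedding_fn(text: str) -> np.ndarray:
    # stand-in for any embedding backend; must return a vector comparable
    # to the ones stored in matched_documents["embedding"]
    return np.ones(3)

validator = Validator(use_reranking=True, validate_documents=True)

relevant, response = validator.check_question_relevance("What is an LLM?")
# ...after generating an answer and retrieving matched_documents (a DataFrame
# with an "embedding" column), the documents can be re-ranked with any backend:
# reranked = validator.rerank_docs(answer, matched_documents, embedding_fn=my_embedding_fn)
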
90 changes: 0 additions & 90 deletions buster/validators/question_answer_validator.py

This file was deleted.
