Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GSK-3609 Avoid redundant questions in data generation #1990

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions giskard/rag/knowledge_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,17 @@ def get_failure_plot(self, question_evaluation: Sequence[dict] = None):
def get_random_document(self):
return self._rng.choice(self._documents)

def get_random_documents(self, n: int, with_replacement=False):
if with_replacement:
return list(self._rng.choice(self._documents, n, replace=True))

docs = list(self._rng.choice(self._documents, min(n, len(self._documents)), replace=False))

if len(docs) <= n:
docs.extend(self._rng.choice(self._documents, n - len(docs), replace=True))

return docs

def get_neighbors(self, seed_document: Document, n_neighbors: int = 4, similarity_threshold: float = 0.2):
seed_embedding = seed_document.embeddings

Expand Down
6 changes: 4 additions & 2 deletions giskard/rag/question_generators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,11 @@ class GenerateFromSingleQuestionMixin:
_question_type: str

def generate_questions(self, knowledge_base: KnowledgeBase, num_questions: int, *args, **kwargs) -> Iterator[Dict]:
for _ in range(num_questions):
docs = knowledge_base.get_random_documents(num_questions)

for doc in docs:
try:
yield self.generate_single_question(knowledge_base, *args, **kwargs)
yield self.generate_single_question(knowledge_base, *args, **kwargs, seed_document=doc)
except Exception as e: # @TODO: specify exceptions
logger.error(f"Encountered error in question generation: {e}. Skipping.")
logger.exception(e)
Expand Down
4 changes: 2 additions & 2 deletions giskard/rag/question_generators/double_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ class DoubleQuestionsGenerator(GenerateFromSingleQuestionMixin, _LLMBasedQuestio
_question_type = "double"

def generate_single_question(
self, knowledge_base: KnowledgeBase, agent_description: str, language: str
self, knowledge_base: KnowledgeBase, agent_description: str, language: str, seed_document=None
) -> QuestionSample:
seed_document = knowledge_base.get_random_document()
seed_document = seed_document or knowledge_base.get_random_document()
context_documents = knowledge_base.get_neighbors(
seed_document, self._context_neighbors, self._context_similarity_threshold
)
Expand Down
4 changes: 2 additions & 2 deletions giskard/rag/question_generators/oos_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class OutOfScopeGenerator(GenerateFromSingleQuestionMixin, _LLMBasedQuestionGene
_question_type = "out of scope"

def generate_single_question(
self, knowledge_base: KnowledgeBase, agent_description: str, language: str
self, knowledge_base: KnowledgeBase, agent_description: str, language: str, seed_document=None
) -> QuestionSample:
"""
Generate a question from a list of context documents.
Expand All @@ -87,7 +87,7 @@ def generate_single_question(
Tuple[dict, dict]
The generated question and the metadata of the question.
"""
seed_document = knowledge_base.get_random_document()
seed_document = seed_document or knowledge_base.get_random_document()

context_documents = knowledge_base.get_neighbors(
seed_document, self._context_neighbors, self._context_similarity_threshold
Expand Down
11 changes: 9 additions & 2 deletions giskard/rag/question_generators/simple_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,13 @@ class SimpleQuestionsGenerator(GenerateFromSingleQuestionMixin, _LLMBasedQuestio

_question_type = "simple"

def generate_single_question(self, knowledge_base: KnowledgeBase, agent_description: str, language: str) -> dict:
def generate_single_question(
self,
knowledge_base: KnowledgeBase,
agent_description: str,
language: str,
seed_document=None,
) -> dict:
"""
Generate a question from a list of context documents.

Expand All @@ -80,7 +86,8 @@ def generate_single_question(self, knowledge_base: KnowledgeBase, agent_descript
QuestionSample
The generated question and the metadata of the question.
"""
seed_document = knowledge_base.get_random_document()
seed_document = seed_document or knowledge_base.get_random_document()

context_documents = knowledge_base.get_neighbors(
seed_document, self._context_neighbors, self._context_similarity_threshold
)
Expand Down
2,404 changes: 1,214 additions & 1,190 deletions pdm.lock

Large diffs are not rendered by default.

25 changes: 25 additions & 0 deletions tests/rag/test_knowledge_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,31 @@
from giskard.rag.knowledge_base import KnowledgeBase


def test_knowledge_base_get_random_documents():
llm_client = Mock()
embeddings = Mock()
embeddings.embed.side_effect = [np.random.rand(5, 10), np.random.rand(3, 10)]

kb = KnowledgeBase.from_pandas(
df=pd.DataFrame({"text": ["This is a test string"] * 5}), llm_client=llm_client, embedding_model=embeddings
)

# Test when k is smaller than the number of documents
docs = kb.get_random_documents(3)
assert len(docs) == 3
assert all([doc in kb._documents for doc in docs])

# Test when k is equal to the number of documents
docs = kb.get_random_documents(5)
assert len(docs) == 5
assert all([doc in kb._documents for doc in docs])

# Test when k is larger than the number of documents
docs = kb.get_random_documents(10)
assert len(docs) == 10
assert all([doc in kb._documents for doc in docs])


def test_knowledge_base_creation_from_df():
dimension = 8
df = pd.DataFrame(["This is a test string"] * 5)
Expand Down
3 changes: 3 additions & 0 deletions tests/rag/test_question_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def test_simple_question_generation():
Document(dict(content="Milk is produced by cows, goats or sheep.")),
]
knowledge_base.get_random_document = Mock(return_value=documents[0])
knowledge_base.get_random_documents = Mock(return_value=documents)
knowledge_base.get_neighbors = Mock(return_value=documents)

question_generator = SimpleQuestionsGenerator(llm_client=llm_client)
Expand Down Expand Up @@ -212,6 +213,7 @@ def test_double_question_generation():
Document(dict(content="Milk is produced by cows, goats or sheep.")),
]
knowledge_base.get_random_document = Mock(return_value=documents[0])
knowledge_base.get_random_documents = Mock(return_value=documents)
knowledge_base.get_neighbors = Mock(return_value=documents)

question_generator = DoubleQuestionsGenerator(llm_client=llm_client)
Expand Down Expand Up @@ -304,6 +306,7 @@ def test_oos_question_generation():
dict(content="Paul Graham liked to buy a baguette every day at the local market."), doc_id="1"
)
)
knowledge_base.get_random_documents = Mock(return_value=documents)
knowledge_base.get_neighbors = Mock(return_value=documents)

question_generator = OutOfScopeGenerator(llm_client=llm_client)
Expand Down
2 changes: 2 additions & 0 deletions tests/rag/test_testset_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ def test_question_generation_fail(caplog):
knowledge_base.__getitem__ = lambda obj, idx: documents[0]
knowledge_base.topics = ["Cheese", "Ski"]

knowledge_base.get_random_documents = Mock(return_value=documents)

simple_gen = Mock()
simple_gen.generate_questions.return_value = [q1, q2]
failing_gen = SimpleQuestionsGenerator(llm_client=Mock())
Expand Down
Loading