Skip to content

Commit

Permalink
Trying to appease linter
Browse files Browse the repository at this point in the history
  • Loading branch information
tleyden committed Dec 4, 2023
1 parent 9dca99f commit 293a169
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 13 deletions.
2 changes: 1 addition & 1 deletion dalm/datasets/reading_comprehension_generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def extract_question_or_answer(text: str, extract_type: str = "question") -> Tup
extraction_regex = rf".*\[?{extract_type}[:\]]*(?:.*?\])?\s*(.*)"

match = re.match(extraction_regex, text, re.IGNORECASE)
extracted_text = match.group(1) if match else None
extracted_text = match.group(1) if match else ""
found_extracted = True if extracted_text else False
return found_extracted, extracted_text

Expand Down
41 changes: 29 additions & 12 deletions tests/datasets/reading_comprehension_generation/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from dalm.datasets.reading_comprehension_generation.utils import _raw_question_and_answer_extractor, question_and_answer_extractor
import pdb
from typing import Dict, Iterator, List, Optional, Tuple

def test_question_and_answer_extractor():

def test_question_and_answer_extractor() -> None:
chat_completions = question_and_answer_extractor(
whole_text="""
1. QUESTION: Can you summarize the purpose of population imaging studies and how they contribute to preventing or treating disease?
Expand All @@ -22,6 +23,8 @@ def test_question_and_answer_extractor():
)
print(chat_completions)

assert chat_completions is not None

# The first chat completion item should be a user prompt, and it should start with "Based on the following text:"
assert chat_completions[0]["content"].startswith("Based on the following text:")
assert chat_completions[0]["role"] == "user"
Expand All @@ -46,9 +49,9 @@ def test_question_and_answer_extractor():



def test_raw_question_and_answer_extractor():
def test_raw_question_and_answer_extractor() -> None:

inputs = [
test_cases = [
{
"whole_text": """
QUESTION: What is the focus?
Expand Down Expand Up @@ -182,14 +185,28 @@ def test_raw_question_and_answer_extractor():

]

for input in inputs:
result_qa_pairs = _raw_question_and_answer_extractor(whole_text=input["whole_text"])
expected_qa_pairs = input["expected_output"]
for result, expected in zip(result_qa_pairs, expected_qa_pairs):
result_question = result["question"].strip().lower()
expected_question = expected["question"].strip().lower()
result_answer = result["answer"].strip().lower()
expected_answer = expected["answer"].strip().lower()
for test_case in test_cases:

result_qa_pairs = _raw_question_and_answer_extractor(
whole_text=str(test_case["whole_text"])
)

expected_qa_pairs = test_case["expected_output"]

assert result_qa_pairs is not None
assert expected_qa_pairs is not None

assert len(result_qa_pairs) == len(expected_qa_pairs)

for i, result_qa_pair in enumerate(result_qa_pairs):
expected_qa_pair = expected_qa_pairs[i]
assert result_qa_pair is not None
assert expected_qa_pair is not None

result_question = result_qa_pair["question"].strip().lower()
expected_question = expected_qa_pair["question"].strip().lower()
result_answer = result_qa_pair["answer"].strip().lower()
expected_answer = expected_qa_pair["answer"].strip().lower()
assert result_question == expected_question, f"result_question: {result_question} != expected_question: {expected_question}"
assert result_answer == expected_answer, f"result_answer: {result_answer} != expected_answer: {expected_answer}"

Expand Down

0 comments on commit 293a169

Please sign in to comment.