From 293a169afa571ac0d309bf5085e4d8d8c37c5ff5 Mon Sep 17 00:00:00 2001 From: Traun Leyden Date: Mon, 4 Dec 2023 22:08:54 +0100 Subject: [PATCH] Trying to appease linter --- .../reading_comprehension_generation/utils.py | 2 +- .../test_utils.py | 41 +++++++++++++------ 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/dalm/datasets/reading_comprehension_generation/utils.py b/dalm/datasets/reading_comprehension_generation/utils.py index 5958dba..c04f150 100644 --- a/dalm/datasets/reading_comprehension_generation/utils.py +++ b/dalm/datasets/reading_comprehension_generation/utils.py @@ -172,7 +172,7 @@ def extract_question_or_answer(text: str, extract_type: str = "question") -> Tup extraction_regex = rf".*\[?{extract_type}[:\]]*(?:.*?\])?\s*(.*)" match = re.match(extraction_regex, text, re.IGNORECASE) - extracted_text = match.group(1) if match else None + extracted_text = match.group(1) if match else "" found_extracted = True if extracted_text else False return found_extracted, extracted_text diff --git a/tests/datasets/reading_comprehension_generation/test_utils.py b/tests/datasets/reading_comprehension_generation/test_utils.py index e485ada..a6567d4 100644 --- a/tests/datasets/reading_comprehension_generation/test_utils.py +++ b/tests/datasets/reading_comprehension_generation/test_utils.py @@ -1,7 +1,8 @@ from dalm.datasets.reading_comprehension_generation.utils import _raw_question_and_answer_extractor, question_and_answer_extractor -import pdb +from typing import Dict, Iterator, List, Optional, Tuple -def test_question_and_answer_extractor(): + +def test_question_and_answer_extractor() -> None: chat_completions = question_and_answer_extractor( whole_text=""" 1. QUESTION: Can you summarize the purpose of population imaging studies and how they contribute to preventing or treating disease? @@ -22,6 +23,8 @@ def test_question_and_answer_extractor(): ) print(chat_completions) + assert chat_completions is not None + # The first chat completion item should be a user prompt, and it should start with "Based on the following text:" assert chat_completions[0]["content"].startswith("Based on the following text:") assert chat_completions[0]["role"] == "user" @@ -46,9 +49,9 @@ def test_question_and_answer_extractor(): -def test_raw_question_and_answer_extractor(): +def test_raw_question_and_answer_extractor() -> None: - inputs = [ + test_cases = [ { "whole_text": """ QUESTION: What is the focus? @@ -182,14 +185,28 @@ def test_raw_question_and_answer_extractor(): ] - for input in inputs: - result_qa_pairs = _raw_question_and_answer_extractor(whole_text=input["whole_text"]) - expected_qa_pairs = input["expected_output"] - for result, expected in zip(result_qa_pairs, expected_qa_pairs): - result_question = result["question"].strip().lower() - expected_question = expected["question"].strip().lower() - result_answer = result["answer"].strip().lower() - expected_answer = expected["answer"].strip().lower() + for test_case in test_cases: + + result_qa_pairs = _raw_question_and_answer_extractor( + whole_text=str(test_case["whole_text"]) + ) + + expected_qa_pairs = test_case["expected_output"] + + assert result_qa_pairs is not None + assert expected_qa_pairs is not None + + assert len(result_qa_pairs) == len(expected_qa_pairs) + + for i, result_qa_pair in enumerate(result_qa_pairs): + expected_qa_pair = expected_qa_pairs[i] + assert result_qa_pair is not None + assert expected_qa_pair is not None + + result_question = result_qa_pair["question"].strip().lower() + expected_question = expected_qa_pair["question"].strip().lower() + result_answer = result_qa_pair["answer"].strip().lower() + expected_answer = expected_qa_pair["answer"].strip().lower() assert result_question == expected_question, f"result_question: {result_question} != expected_question: {expected_question}" assert result_answer == expected_answer, f"result_answer: {result_answer} != expected_answer: {expected_answer}"