diff --git a/tests/test_step.py b/tests/test_step.py new file mode 100644 index 0000000..66592cd --- /dev/null +++ b/tests/test_step.py @@ -0,0 +1,25 @@ +STEP_FILE = """ +===== QUESTION_EXAMPLE ===== +Question example no. 1 +===== ANSWER ===== +Answer example no. 1 +===== QUESTION_EXAMPLE ===== +Question example no. 2 +===== ANSWER ===== +Answer example no. 2 + +===== QUESTION ===== +Question example no. 2 +""" + + +def test_step_file(): + from verified_cogen.runners.chain_of_thought import Step + from verified_cogen.runners.chain_of_thought.step import Substep + + step = Step(STEP_FILE) + assert step.question == "Question example no. 2\n" + assert step.substeps == [ + Substep("Question example no. 1", "Answer example no. 1"), + Substep("Question example no. 2", "Answer example no. 2\n"), + ] diff --git a/verified_cogen/runners/chain_of_thought/step.py b/verified_cogen/runners/chain_of_thought/step.py index ef3ddb0..5c2698b 100644 --- a/verified_cogen/runners/chain_of_thought/step.py +++ b/verified_cogen/runners/chain_of_thought/step.py @@ -6,6 +6,14 @@ def __init__(self, question: str, answer: str): self.question = question self.answer = answer + def __repr__(self) -> str: + return f"Substep(question={self.question}, answer={self.answer})" + + def __eq__(self, value: object, /) -> bool: + if not isinstance(value, Substep): + return False + return self.question == value.question and self.answer == value.answer + class Step: substeps: list[Substep] = [] @@ -26,10 +34,15 @@ def __init__(self, data: str): elif line.startswith("===== ANSWER ====="): appending_to = current_answer elif line.startswith("===== QUESTION ====="): + self.substeps.append( + Substep("\n".join(current_question), "\n".join(current_answer)) + ) + self.question = "\n".join(data_lines[i + 1 :]) break else: appending_to.append(line) + self.substeps = self.substeps[1:] def __repr__(self) -> str: return f"Step(question={self.question}, substeps={self.substeps})"