diff --git a/paperqa/agents/task.py b/paperqa/agents/task.py
index 31d160f6..c2033e1b 100644
--- a/paperqa/agents/task.py
+++ b/paperqa/agents/task.py
@@ -196,6 +196,7 @@ def __init__(
         base_query: QueryRequest | dict | None = None,
         base_docs: Docs | dict | None = None,
         rewards: Mapping[str, float] = DEFAULT_REWARD_MAPPING,
+        question_kwargs: Mapping[str, Any] | None = None,
         eval_model: LLMModel | str = DEFAULT_EVAL_MODEL_NAME,
         **env_kwargs,
     ):
@@ -210,23 +211,23 @@ def __init__(
             base_docs = Docs(**base_docs)
         self._base_docs = base_docs
         self._rewards = rewards
-        self._env_kwargs = env_kwargs
+        self._question_kwargs = question_kwargs
         self._eval_model = eval_model
+        self._env_kwargs = env_kwargs
 
     def _make_gradable_environment(
         self,
         ideal: str,
         distractors: str | list[str],
         question: str,
-        use_unsure: bool = True,
         sources: str | list[str] | None = None,
     ) -> GradablePaperQAEnvironment:
         qa_prompt, evaluation_from_answer = LitQAEvaluation.from_question(
             ideal=ideal,
             distractors=distractors,
             question=question,
-            use_unsure=use_unsure,
             eval_model=self._eval_model,
+            **(self._question_kwargs or {}),
         )
         query = self._base_query.model_copy()
         query.query = qa_prompt
@@ -305,11 +306,14 @@ def __init__(
         self,
         *args,
         labbench_dataset: str = DEFAULT_LABBENCH_HF_HUB_NAME,
+        read_data_kwargs: Mapping[str, Any] | None = None,
         split: str | LitQAv2TaskSplit = LitQAv2TaskSplit.EVAL,
         **kwargs,
     ):
         super().__init__(*args, **kwargs)
-        train_df, eval_df = read_litqa_v2_from_hub(labbench_dataset)
+        train_df, eval_df = read_litqa_v2_from_hub(
+            labbench_dataset, **(read_data_kwargs or {})
+        )
         split = LitQAv2TaskSplit(split)
         if split == LitQAv2TaskSplit.TRAIN:
             self.data = train_df
diff --git a/paperqa/litqa.py b/paperqa/litqa.py
index 461970df..f585cd48 100644
--- a/paperqa/litqa.py
+++ b/paperqa/litqa.py
@@ -8,7 +8,7 @@
 from ast import literal_eval
 from collections.abc import Awaitable, Callable, Mapping, Sequence
 from enum import StrEnum
-from typing import TYPE_CHECKING, Self
+from typing import TYPE_CHECKING, Literal, Self
 
 try:
     from ldp.utils import discounted_returns
@@ -92,6 +92,7 @@ def make_mc_options(
 
 DEFAULT_EVAL_MODEL_NAME = "gpt-4-turbo-2024-04-09"
 DEFAULT_REWARD_MAPPING = {"correct": 1.0, "unsure": 0.1, "incorrect": -1.0}
+SEED_USING_QUESTION: Literal["SEED_USING_QUESTION"] = "SEED_USING_QUESTION"  # Sentinel
 
 
 class LitQAEvaluation(StrEnum):
@@ -161,7 +162,7 @@ def from_question(
         question: str,
         use_unsure: bool = True,
         eval_model: LLMModel | str = DEFAULT_EVAL_MODEL_NAME,
-        seed: int | None = None,
+        seed: int | Literal["SEED_USING_QUESTION"] | None = None,
     ) -> tuple[str, Callable[[PQASession | str], Awaitable[LitQAEvaluation]]]:
         """
         Create a LitQA question and an answer-to-evaluation function.
@@ -174,11 +175,15 @@
             eval_model: Evaluation model to use for multiple choice letter extraction
                 from a text answer.
             seed: Optional seed to use in randomization of multiple choice letters.
+                Optionally pass in the string literal "SEED_USING_QUESTION" to hash the
+                input question for the seed.
 
         Returns:
             Two-tuple of created LitQA question, function (that can be thought of
             as stateless) to use to extract an evaluation result from an answer.
         """
+        if seed == SEED_USING_QUESTION:
+            seed = hash(question)
         text, ideal_answer, unsure_answer, distractor_answers = make_mc_options(
             ideal=ideal,
             distractors=distractors,
diff --git a/tests/test_litqa.py b/tests/test_litqa.py
index eac0a997..90d9ce25 100644
--- a/tests/test_litqa.py
+++ b/tests/test_litqa.py
@@ -3,7 +3,7 @@
 
 import pytest
 
-from paperqa.litqa import LitQAEvaluation, read_litqa_v2_from_hub
+from paperqa.litqa import SEED_USING_QUESTION, LitQAEvaluation, read_litqa_v2_from_hub
 from tests.conftest import VCR_DEFAULT_MATCH_ON
 
 
@@ -140,16 +140,38 @@ def test_consistent_mc_options(self) -> None:
         """Tests that creating multiple evaluations with the same seed results in the same prompt."""
         question, ideal, distractors = self.MEANING_OF_LIFE_QUESTION_IDEAL_DISTRACTORS
 
-        qa_prompt_1, _ = LitQAEvaluation.from_question(
+        qa_prompt_1a, _ = LitQAEvaluation.from_question(
             ideal=ideal, distractors=distractors, question=question, seed=0
         )
-        self._assert_prompt_is_valid(qa_prompt_1, question, ideal, distractors)
+        self._assert_prompt_is_valid(qa_prompt_1a, question, ideal, distractors)
 
-        qa_prompt_2, _ = LitQAEvaluation.from_question(
+        qa_prompt_1b, _ = LitQAEvaluation.from_question(
             ideal=ideal, distractors=distractors, question=question, seed=0
         )
-        self._assert_prompt_is_valid(qa_prompt_1, question, ideal, distractors)
-        assert qa_prompt_1 == qa_prompt_2
+        self._assert_prompt_is_valid(qa_prompt_1b, question, ideal, distractors)
+        assert qa_prompt_1a == qa_prompt_1b, "Same seeding should lead to same prompts"
+
+        qa_prompt_2a, _ = LitQAEvaluation.from_question(
+            ideal=ideal,
+            distractors=distractors,
+            question=question,
+            seed=SEED_USING_QUESTION,
+        )
+        self._assert_prompt_is_valid(qa_prompt_2a, question, ideal, distractors)
+
+        qa_prompt_2b, _ = LitQAEvaluation.from_question(
+            ideal=ideal,
+            distractors=distractors,
+            question=question,
+            seed=SEED_USING_QUESTION,
+        )
+        self._assert_prompt_is_valid(qa_prompt_2b, question, ideal, distractors)
+        assert (
+            qa_prompt_2a == qa_prompt_2b
+        ), "Same seeding strategy should lead to same prompts"
+        assert (
+            qa_prompt_2a != qa_prompt_1a
+        ), "Different seeding strategies should lead to different prompts"
 
     def test_creating_litqa_questions(self) -> None:
         """Test making LitQA eval questions after downloading from Hugging Face Hub."""
diff --git a/tests/test_task.py b/tests/test_task.py
index 4ea30875..e1f14d83 100644
--- a/tests/test_task.py
+++ b/tests/test_task.py
@@ -20,7 +20,7 @@
     LitQAv2TaskSplit,
 )
 from paperqa.agents.tools import GenerateAnswer
-from paperqa.litqa import DEFAULT_REWARD_MAPPING, LitQAEvaluation
+from paperqa.litqa import DEFAULT_REWARD_MAPPING, SEED_USING_QUESTION, LitQAEvaluation
 
 
 @pytest.fixture(name="base_query_request")
@@ -103,12 +103,27 @@ async def test___len__(
         expected_length: int,
         base_query_request: QueryRequest,
     ) -> None:
-        task_dataset = LitQAv2TaskDataset(base_query=base_query_request, split=split)
+        task_dataset = LitQAv2TaskDataset(
+            base_query=base_query_request,
+            question_kwargs={"seed": 42},
+            read_data_kwargs={"seed": 42},
+            split=split,
+        )
         assert len(task_dataset) == expected_length
 
         # Now let's check we could use the sources in a validation
         for i in range(len(task_dataset)):
             env = task_dataset.get_new_env_by_idx(i)
+            if i == 0 and split == LitQAv2TaskSplit.TRAIN:
+                # Yes this assertion is somewhat brittle, but it reliably
+                # checks the seeding's behavior so we keep it
+                obs, _ = await env.reset()
+                assert (
+                    "Q: SLC14A1 been identified as a specific marker for endothelial"
+                    " cells in which organ?\n\nOptions:\nA) heart\nB) eye\nC)"
+                    " prostate\nD) Insufficient information to answer this question\nE)"
+                    " liver" in (obs[0].content or "")
+                )
             assert env.sources, "Sources need to be accessible"
             assert isinstance(
                 env.sources, Iterable
@@ -144,6 +159,7 @@ async def test_evaluation(
                         "deleted_dockeys",
                     }
                 ),
+                "question_kwargs": {"seed": SEED_USING_QUESTION},
            },
        )
        # NOTE: set base_query after construction of the TaskConfig. because in
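
Usage note (not part of the patch): a hedged sketch of how the new `SEED_USING_QUESTION` sentinel could be passed to `LitQAEvaluation.from_question`, using the import shown in the test changes above. The question, ideal answer, and distractors below are placeholder values, not dataset content.

```python
from paperqa.litqa import SEED_USING_QUESTION, LitQAEvaluation

# Placeholder inputs for illustration only.
question = "What is the meaning of life?"
ideal = "42"
distractors = ["-84", "11", "cheesecake"]

# Passing the sentinel makes from_question derive the seed via hash(question),
# so repeated calls with the same question shuffle the multiple-choice letters
# identically (within one interpreter process, where str hashing is stable),
# while different questions still get different orderings.
qa_prompt, evaluation_from_answer = LitQAEvaluation.from_question(
    ideal=ideal,
    distractors=distractors,
    question=question,
    seed=SEED_USING_QUESTION,
)
print(qa_prompt)  # Reproducible letter ordering for this question in this process
# evaluation_from_answer is an async callable that grades an answer
# (PQASession or str) into a LitQAEvaluation.
```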
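Similarly, a hedged sketch of the new kwargs plumbing on the task dataset, assuming `LitQAv2TaskDataset` and `LitQAv2TaskSplit` live in `paperqa.agents.task` (the module this diff edits); the seed values are arbitrary.

```python
from paperqa.agents.task import LitQAv2TaskDataset, LitQAv2TaskSplit  # Assumed import path
from paperqa.litqa import SEED_USING_QUESTION

# Constructing the dataset reads LitQA v2 from Hugging Face Hub.
dataset = LitQAv2TaskDataset(
    # Forwarded into LitQAEvaluation.from_question when each environment is built
    question_kwargs={"seed": SEED_USING_QUESTION},
    # Forwarded into read_litqa_v2_from_hub, per the test usage above
    read_data_kwargs={"seed": 42},
    split=LitQAv2TaskSplit.TRAIN,
)
env = dataset.get_new_env_by_idx(0)  # Its QA prompt is now reproducibly seeded
```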