From 8bfb740037118893097c9378e42b32c393b44eb9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci-lite[bot]" <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Date: Mon, 3 Feb 2025 16:54:18 +0000 Subject: [PATCH] [pre-commit.ci lite] apply automatic fixes --- gradable.py | 10 ++++++++-- paperqa/agents/task.py | 13 +++++++++++-- paperqa/settings.py | 3 +-- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/gradable.py b/gradable.py index d8db9b7e..889749f8 100644 --- a/gradable.py +++ b/gradable.py @@ -1,5 +1,6 @@ -import os import asyncio +import os + from aviary.env import TaskDataset from ldp.agent import SimpleAgent from ldp.alg.callbacks import MeanMetricsCallback @@ -8,6 +9,7 @@ from paperqa import Settings from paperqa.agents.task import LFRQATaskDataset + async def evaluate() -> None: settings = Settings() settings.agent.index.name = "lfrqa_science_index_complete" @@ -22,7 +24,11 @@ async def evaluate() -> None: settings.parsing.use_doc_details = False - dataset = LFRQATaskDataset(data_path="rag-qa-benchmarking/lfrqa/questions.csv", num_questions=2, settings=settings) + dataset = LFRQATaskDataset( + data_path="rag-qa-benchmarking/lfrqa/questions.csv", + num_questions=2, + settings=settings, + ) metrics_callback = MeanMetricsCallback(eval_dataset=dataset) evaluator = Evaluator( diff --git a/paperqa/agents/task.py b/paperqa/agents/task.py index 8415b965..bcf048ca 100644 --- a/paperqa/agents/task.py +++ b/paperqa/agents/task.py @@ -578,8 +578,17 @@ async def pairwise_evaluation( print(f"PQa answer was:\n{pqa_answer} \n\n") print(f"Human answer was:\n{human_answer} \n\n") print(f"Winner is: {winner}\n") - self.log_results_to_json(self._settings.llm, qid, question, pqa_answer, human_answer, pqa_answer_index, winner, result) - + self.log_results_to_json( + self._settings.llm, + qid, + question, + pqa_answer, + human_answer, + pqa_answer_index, + winner, + result, + ) + reward = ( self._rewards["win"] if winner == "paperqa" diff --git a/paperqa/settings.py b/paperqa/settings.py index 1f0dc67c..4a462abc 100644 --- a/paperqa/settings.py +++ b/paperqa/settings.py @@ -837,7 +837,7 @@ def get_summary_llm(self) -> LiteLLMModel: self.summary_llm, self.temperature ), ) - + def get_pairwise_eval_llm(self) -> LiteLLMModel: return LiteLLMModel( name=self.pair_eval_llm, @@ -855,7 +855,6 @@ def get_agent_llm(self) -> LiteLLMModel: self.agent.agent_llm, self.temperature ), ) - def get_embedding_model(self) -> EmbeddingModel: return embedding_model_factory(self.embedding, **(self.embedding_config or {}))