Skip to content

Commit 423db03

Browse files
committed
Added consensus sampling helper function and storage callback
1 parent c221213 commit 423db03

File tree

2 files changed

+110
-7
lines changed

2 files changed

+110
-7
lines changed

paperqa/_ldp_shims.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"UIndexMemoryModel",
1616
"_Memories",
1717
"discounted_returns",
18+
"evaluate_consensus",
1819
"set_training_mode",
1920
]
2021

@@ -29,7 +30,12 @@
2930
SimpleAgent,
3031
SimpleAgentState,
3132
)
32-
from ldp.alg import Callback, ComputeTrajectoryMetricsMixin, RolloutManager
33+
from ldp.alg import (
34+
Callback,
35+
ComputeTrajectoryMetricsMixin,
36+
RolloutManager,
37+
evaluate_consensus,
38+
)
3339
from ldp.graph.memory import Memory, UIndexMemoryModel
3440
from ldp.graph.op_utils import set_training_mode
3541
from ldp.utils import discounted_returns
@@ -57,4 +63,5 @@ class Callback: # type: ignore[no-redef]
5763
SimpleAgentState = None # type: ignore[assignment,misc]
5864
UIndexMemoryModel = None # type: ignore[assignment,misc]
5965
discounted_returns = None # type: ignore[assignment]
66+
evaluate_consensus = None # type: ignore[assignment]
6067
set_training_mode = None # type: ignore[assignment]

paperqa/agents/task.py

Lines changed: 102 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@
1010
import logging
1111
import re
1212
from abc import ABC
13-
from collections.abc import Awaitable, Callable, Mapping, Sequence
13+
from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence
1414
from copy import deepcopy
1515
from enum import StrEnum
16-
from typing import TYPE_CHECKING, Any, Self, assert_never
16+
from typing import TYPE_CHECKING, Any, Self, assert_never, cast
1717

1818
from aviary.core import (
1919
TASK_DATASET_REGISTRY,
20+
Environment,
21+
Frame,
2022
Messages,
2123
TaskDataset,
2224
ToolRequestMessage,
@@ -30,22 +32,27 @@
3032
)
3133
from llmclient import EmbeddingModel, LiteLLMModel, LLMModel
3234

33-
from paperqa._ldp_shims import ComputeTrajectoryMetricsMixin
35+
from paperqa._ldp_shims import (
36+
Callback,
37+
ComputeTrajectoryMetricsMixin,
38+
evaluate_consensus,
39+
)
3440
from paperqa.docs import Docs
3541
from paperqa.litqa import (
3642
DEFAULT_LABBENCH_HF_HUB_NAME,
3743
DEFAULT_REWARD_MAPPING,
3844
read_litqa_v2_from_hub,
3945
)
40-
from paperqa.types import DocDetails
46+
from paperqa.types import DocDetails, PQASession
4147

4248
from .env import POPULATE_FROM_SETTINGS, PaperQAEnvironment
4349
from .models import QueryRequest
4450
from .search import SearchIndex, maybe_get_manifest
45-
from .tools import Complete
51+
from .tools import Complete, EnvironmentState
4652

4753
if TYPE_CHECKING:
48-
from ldp.data_structures import Trajectory
54+
from ldp.agent import Agent
55+
from ldp.data_structures import Trajectory, Transition
4956

5057
logger = logging.getLogger(__name__)
5158

@@ -169,6 +176,95 @@ def __deepcopy__(self, memo) -> Self:
169176
)
170177

171178

179+
async def evaluate_consensus_sampling(
    data: Iterable[GradablePaperQAEnvironment | Frame],
    num_samples: int = 1,
    seed: int | None = None,
) -> tuple[dict[str, list[tuple[str, int]]], float]:
    """Evaluate consensus among sampled answers, grouped by question.

    Args:
        data: Environments (or their frames) each carrying a query and a
            session with a graded answer.
        num_samples: Number of answers to sample per question group.
        seed: Optional seed for reproducible sampling.

    Returns:
        Two-tuple of the consensus mapping (question -> list of
        (answer, count) pairs) and the accuracy of the consensus against
        the ideal answers, as produced by ldp's evaluate_consensus.

    Raises:
        ImportError: If the 'ldp' extra is not installed.
        ValueError: If an ideal answer is requested for a plain-string query.
    """

    def get_query(
        x: GradablePaperQAEnvironment | Frame,
    ) -> str | MultipleChoiceQuestion | dict[str, Any]:
        # Environments hold the query directly; frames stash it (possibly
        # serialized to a dict) under info["query"].
        if isinstance(x, GradablePaperQAEnvironment):
            return x._query.query
        qr: QueryRequest | dict[str, Any] = x.info["query"]  # type: ignore[call-overload,index]
        return qr.query if isinstance(qr, QueryRequest) else qr["query"]

    def extract_question(x: GradablePaperQAEnvironment | Frame) -> str:
        query = get_query(x)
        if isinstance(query, str):
            return query
        if isinstance(query, MultipleChoiceQuestion):
            return query.question_prompt
        return query["question"]

    def extract_answer(x: GradablePaperQAEnvironment | Frame) -> str:
        # The session may be a live PQASession or a serialized dict in a frame.
        sess: PQASession | dict[str, Any] = (
            x.state.session
            if isinstance(x.state, EnvironmentState)
            else cast(PQASession | dict[str, Any], x.state["session"])  # type: ignore[call-overload,index]
        )
        return (
            sess.graded_answer
            if isinstance(sess, PQASession)
            else sess["graded_answer"]
        ) or ""

    def extract_ideal(x: GradablePaperQAEnvironment | Frame) -> str:
        query = get_query(x)
        if isinstance(query, str):
            # NOTE: this was a plain string missing the f-prefix, which
            # emitted the literal braces instead of the class name.
            raise ValueError(  # noqa: TRY004
                f"We require a {MultipleChoiceQuestion.__name__} variant to extract"
                " ideal answer, not a string."
            )
        if isinstance(query, MultipleChoiceQuestion):
            return query.ideal_answer
        return query["ideal_answer"]

    try:
        # evaluate_consensus is None when the 'ldp' extra isn't installed,
        # so calling it raises TypeError, which we convert to ImportError.
        return await evaluate_consensus(
            data=data,
            grouping_fn=extract_question,
            extract_answer_fn=extract_answer,
            ideal_answer_fn=extract_ideal,
            num_samples=num_samples,
            seed=seed,
        )
    except TypeError:
        raise ImportError(
            "Evaluating consensus requires the 'ldp' extra for 'ldp'. Please:"
            " `pip install paper-qa[ldp]`."
        ) from None
238+
239+
class StoreForConsensusSamplingCallback(Callback):
    """Callback that accumulates finished environments for consensus evaluation."""

    def __init__(self):
        super().__init__()
        # Environments (or frames) captured at the end of each trajectory.
        self.stored: list[GradablePaperQAEnvironment | Frame] = []

    async def after_transition(
        self,
        traj_id: str,  # noqa: ARG002
        agent: "Agent",  # noqa: ARG002
        env: Environment,
        transition: "Transition",
    ) -> None:
        """Store the environment once its trajectory reaches a terminal transition."""
        if not isinstance(env, GradablePaperQAEnvironment):
            raise NotImplementedError(
                f"So far only handled {GradablePaperQAEnvironment} in this callback,"
                f" not {type(env)}."
            )
        # Append only on the final transition, so each env is stored once.
        if transition.done:
            self.stored.append(env)

    async def evaluate_consensus_sampling(
        self, num_samples: int = 1, seed: int | None = None
    ) -> tuple[dict[str, list[tuple[str, int]]], float]:
        """Run the module-level consensus evaluation over all stored environments."""
        return await evaluate_consensus_sampling(
            data=self.stored, num_samples=num_samples, seed=seed
        )
266+
267+
172268
class LitQATaskDataset(
173269
TaskDataset[GradablePaperQAEnvironment], ComputeTrajectoryMetricsMixin, ABC
174270
):

0 commit comments

Comments
 (0)