|
10 | 10 | import logging
|
11 | 11 | import re
|
12 | 12 | from abc import ABC
|
13 |
| -from collections.abc import Awaitable, Callable, Mapping, Sequence |
| 13 | +from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence |
14 | 14 | from copy import deepcopy
|
15 | 15 | from enum import StrEnum
|
16 |
| -from typing import TYPE_CHECKING, Any, Self, assert_never |
| 16 | +from typing import TYPE_CHECKING, Any, Self, assert_never, cast |
17 | 17 |
|
18 | 18 | from aviary.core import (
|
19 | 19 | TASK_DATASET_REGISTRY,
|
| 20 | + Environment, |
| 21 | + Frame, |
20 | 22 | Messages,
|
21 | 23 | TaskDataset,
|
22 | 24 | ToolRequestMessage,
|
|
30 | 32 | )
|
31 | 33 | from llmclient import EmbeddingModel, LiteLLMModel, LLMModel
|
32 | 34 |
|
33 |
| -from paperqa._ldp_shims import ComputeTrajectoryMetricsMixin |
| 35 | +from paperqa._ldp_shims import ( |
| 36 | + Callback, |
| 37 | + ComputeTrajectoryMetricsMixin, |
| 38 | + evaluate_consensus, |
| 39 | +) |
34 | 40 | from paperqa.docs import Docs
|
35 | 41 | from paperqa.litqa import (
|
36 | 42 | DEFAULT_LABBENCH_HF_HUB_NAME,
|
37 | 43 | DEFAULT_REWARD_MAPPING,
|
38 | 44 | read_litqa_v2_from_hub,
|
39 | 45 | )
|
40 |
| -from paperqa.types import DocDetails |
| 46 | +from paperqa.types import DocDetails, PQASession |
41 | 47 |
|
42 | 48 | from .env import POPULATE_FROM_SETTINGS, PaperQAEnvironment
|
43 | 49 | from .models import QueryRequest
|
44 | 50 | from .search import SearchIndex, maybe_get_manifest
|
45 |
| -from .tools import Complete |
| 51 | +from .tools import Complete, EnvironmentState |
46 | 52 |
|
47 | 53 | if TYPE_CHECKING:
|
48 |
| - from ldp.data_structures import Trajectory |
| 54 | + from ldp.agent import Agent |
| 55 | + from ldp.data_structures import Trajectory, Transition |
49 | 56 |
|
50 | 57 | logger = logging.getLogger(__name__)
|
51 | 58 |
|
@@ -169,6 +176,95 @@ def __deepcopy__(self, memo) -> Self:
|
169 | 176 | )
|
170 | 177 |
|
171 | 178 |
|
| 179 | +async def evaluate_consensus_sampling( |
| 180 | + data: Iterable[GradablePaperQAEnvironment | Frame], |
| 181 | + num_samples: int = 1, |
| 182 | + seed: int | None = None, |
| 183 | +) -> tuple[dict[str, list[tuple[str, int]]], float]: |
| 184 | + def extract_question(x: GradablePaperQAEnvironment | Frame) -> str: |
| 185 | + if isinstance(x, GradablePaperQAEnvironment): |
| 186 | + query: str | MultipleChoiceQuestion | dict[str, Any] = x._query.query |
| 187 | + else: |
| 188 | + qr: QueryRequest | dict[str, Any] = x.info["query"] # type: ignore[call-overload,index,assignment] |
| 189 | + query = qr.query if isinstance(qr, QueryRequest) else qr["query"] |
| 190 | + if isinstance(query, str): |
| 191 | + return query |
| 192 | + if isinstance(query, MultipleChoiceQuestion): |
| 193 | + return query.question_prompt |
| 194 | + return query["question"] |
| 195 | + |
| 196 | + def extract_answer(x: GradablePaperQAEnvironment | Frame) -> str: |
| 197 | + sess: PQASession | dict[str, Any] = ( |
| 198 | + x.state.session |
| 199 | + if isinstance(x.state, EnvironmentState) |
| 200 | + else cast(PQASession | dict[str, Any], x.state["session"]) # type: ignore[call-overload,index] |
| 201 | + ) |
| 202 | + return ( |
| 203 | + sess.graded_answer |
| 204 | + if isinstance(sess, PQASession) |
| 205 | + else sess["graded_answer"] |
| 206 | + ) or "" |
| 207 | + |
| 208 | + def extract_ideal(x: GradablePaperQAEnvironment | Frame) -> str: |
| 209 | + if isinstance(x, GradablePaperQAEnvironment): |
| 210 | + query: str | MultipleChoiceQuestion | dict[str, Any] = x._query.query |
| 211 | + else: |
| 212 | + qr: QueryRequest | dict[str, Any] = x.info["query"] # type: ignore[call-overload,index,assignment] |
| 213 | + query = qr.query if isinstance(qr, QueryRequest) else qr["query"] |
| 214 | + if isinstance(query, str): |
| 215 | + raise ValueError( # noqa: TRY004 |
| 216 | + "We require a {MultipleChoiceQuestion.__name__} variant to extract" |
| 217 | + " ideal answer, not a string." |
| 218 | + ) |
| 219 | + if isinstance(query, MultipleChoiceQuestion): |
| 220 | + return query.ideal_answer |
| 221 | + return query["ideal_answer"] |
| 222 | + |
| 223 | + try: |
| 224 | + return await evaluate_consensus( |
| 225 | + data=data, |
| 226 | + grouping_fn=extract_question, |
| 227 | + extract_answer_fn=extract_answer, |
| 228 | + ideal_answer_fn=extract_ideal, |
| 229 | + num_samples=num_samples, |
| 230 | + seed=seed, |
| 231 | + ) |
| 232 | + except TypeError: |
| 233 | + raise ImportError( |
| 234 | + "Evaluating consensus requires the 'ldp' extra for 'ldp'. Please:" |
| 235 | + " `pip install paper-qa[ldp]`." |
| 236 | + ) from None |
| 237 | + |
| 238 | + |
| 239 | +class StoreForConsensusSamplingCallback(Callback): |
| 240 | + def __init__(self): |
| 241 | + super().__init__() |
| 242 | + self.stored: list[GradablePaperQAEnvironment | Frame] = [] |
| 243 | + |
| 244 | + async def after_transition( |
| 245 | + self, |
| 246 | + traj_id: str, # noqa: ARG002 |
| 247 | + agent: "Agent", # noqa: ARG002 |
| 248 | + env: Environment, |
| 249 | + transition: "Transition", |
| 250 | + ) -> None: |
| 251 | + if not isinstance(env, GradablePaperQAEnvironment): |
| 252 | + raise NotImplementedError( |
| 253 | + f"So far only handled {GradablePaperQAEnvironment} in this callback," |
| 254 | + f" not {type(env)}." |
| 255 | + ) |
| 256 | + if not transition.done: # Only store once |
| 257 | + return |
| 258 | + self.stored.append(env) |
| 259 | + |
| 260 | + async def evaluate_consensus_sampling( |
| 261 | + self, num_samples: int = 1, seed: int | None = None |
| 262 | + ) -> tuple[dict[str, list[tuple[str, int]]], float]: |
| 263 | + return await evaluate_consensus_sampling( |
| 264 | + data=self.stored, num_samples=num_samples, seed=seed |
| 265 | + ) |
| 266 | + |
| 267 | + |
172 | 268 | class LitQATaskDataset(
|
173 | 269 | TaskDataset[GradablePaperQAEnvironment], ComputeTrajectoryMetricsMixin, ABC
|
174 | 270 | ):
|
|
0 commit comments