Skip to content

Commit

Permalink
Added consensus sampling helper function and storage callback
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesbraza committed Jan 2, 2025
1 parent c221213 commit 423db03
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 7 deletions.
9 changes: 8 additions & 1 deletion paperqa/_ldp_shims.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"UIndexMemoryModel",
"_Memories",
"discounted_returns",
"evaluate_consensus",
"set_training_mode",
]

Expand All @@ -29,7 +30,12 @@
SimpleAgent,
SimpleAgentState,
)
from ldp.alg import Callback, ComputeTrajectoryMetricsMixin, RolloutManager
from ldp.alg import (
Callback,
ComputeTrajectoryMetricsMixin,
RolloutManager,
evaluate_consensus,
)
from ldp.graph.memory import Memory, UIndexMemoryModel
from ldp.graph.op_utils import set_training_mode
from ldp.utils import discounted_returns
Expand Down Expand Up @@ -57,4 +63,5 @@ class Callback: # type: ignore[no-redef]
SimpleAgentState = None # type: ignore[assignment,misc]
UIndexMemoryModel = None # type: ignore[assignment,misc]
discounted_returns = None # type: ignore[assignment]
evaluate_consensus = None # type: ignore[assignment]
set_training_mode = None # type: ignore[assignment]
108 changes: 102 additions & 6 deletions paperqa/agents/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
import logging
import re
from abc import ABC
from collections.abc import Awaitable, Callable, Mapping, Sequence
from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence
from copy import deepcopy
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Self, assert_never
from typing import TYPE_CHECKING, Any, Self, assert_never, cast

from aviary.core import (
TASK_DATASET_REGISTRY,
Environment,
Frame,
Messages,
TaskDataset,
ToolRequestMessage,
Expand All @@ -30,22 +32,27 @@
)
from llmclient import EmbeddingModel, LiteLLMModel, LLMModel

from paperqa._ldp_shims import ComputeTrajectoryMetricsMixin
from paperqa._ldp_shims import (
Callback,
ComputeTrajectoryMetricsMixin,
evaluate_consensus,
)
from paperqa.docs import Docs
from paperqa.litqa import (
DEFAULT_LABBENCH_HF_HUB_NAME,
DEFAULT_REWARD_MAPPING,
read_litqa_v2_from_hub,
)
from paperqa.types import DocDetails
from paperqa.types import DocDetails, PQASession

from .env import POPULATE_FROM_SETTINGS, PaperQAEnvironment
from .models import QueryRequest
from .search import SearchIndex, maybe_get_manifest
from .tools import Complete
from .tools import Complete, EnvironmentState

if TYPE_CHECKING:
from ldp.data_structures import Trajectory
from ldp.agent import Agent
from ldp.data_structures import Trajectory, Transition

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -169,6 +176,95 @@ def __deepcopy__(self, memo) -> Self:
)


async def evaluate_consensus_sampling(
data: Iterable[GradablePaperQAEnvironment | Frame],
num_samples: int = 1,
seed: int | None = None,
) -> tuple[dict[str, list[tuple[str, int]]], float]:
def extract_question(x: GradablePaperQAEnvironment | Frame) -> str:
if isinstance(x, GradablePaperQAEnvironment):
query: str | MultipleChoiceQuestion | dict[str, Any] = x._query.query
else:
qr: QueryRequest | dict[str, Any] = x.info["query"] # type: ignore[call-overload,index,assignment]
query = qr.query if isinstance(qr, QueryRequest) else qr["query"]
if isinstance(query, str):
return query
if isinstance(query, MultipleChoiceQuestion):
return query.question_prompt
return query["question"]

def extract_answer(x: GradablePaperQAEnvironment | Frame) -> str:
sess: PQASession | dict[str, Any] = (
x.state.session
if isinstance(x.state, EnvironmentState)
else cast(PQASession | dict[str, Any], x.state["session"]) # type: ignore[call-overload,index]
)
return (
sess.graded_answer
if isinstance(sess, PQASession)
else sess["graded_answer"]
) or ""

def extract_ideal(x: GradablePaperQAEnvironment | Frame) -> str:
if isinstance(x, GradablePaperQAEnvironment):
query: str | MultipleChoiceQuestion | dict[str, Any] = x._query.query
else:
qr: QueryRequest | dict[str, Any] = x.info["query"] # type: ignore[call-overload,index,assignment]
query = qr.query if isinstance(qr, QueryRequest) else qr["query"]
if isinstance(query, str):
raise ValueError( # noqa: TRY004
"We require a {MultipleChoiceQuestion.__name__} variant to extract"
" ideal answer, not a string."
)
if isinstance(query, MultipleChoiceQuestion):
return query.ideal_answer
return query["ideal_answer"]

try:
return await evaluate_consensus(
data=data,
grouping_fn=extract_question,
extract_answer_fn=extract_answer,
ideal_answer_fn=extract_ideal,
num_samples=num_samples,
seed=seed,
)
except TypeError:
raise ImportError(
"Evaluating consensus requires the 'ldp' extra for 'ldp'. Please:"
" `pip install paper-qa[ldp]`."
) from None


class StoreForConsensusSamplingCallback(Callback):
def __init__(self):
super().__init__()
self.stored: list[GradablePaperQAEnvironment | Frame] = []

async def after_transition(
self,
traj_id: str, # noqa: ARG002
agent: "Agent", # noqa: ARG002
env: Environment,
transition: "Transition",
) -> None:
if not isinstance(env, GradablePaperQAEnvironment):
raise NotImplementedError(
f"So far only handled {GradablePaperQAEnvironment} in this callback,"
f" not {type(env)}."
)
if not transition.done: # Only store once
return
self.stored.append(env)

async def evaluate_consensus_sampling(
self, num_samples: int = 1, seed: int | None = None
) -> tuple[dict[str, list[tuple[str, int]]], float]:
return await evaluate_consensus_sampling(
data=self.stored, num_samples=num_samples, seed=seed
)


class LitQATaskDataset(
TaskDataset[GradablePaperQAEnvironment], ComputeTrajectoryMetricsMixin, ABC
):
Expand Down

0 comments on commit 423db03

Please sign in to comment.