Propagates llmclient changes to pqa #839

Merged: 53 commits from update-llmclient into main on Feb 21, 2025
Changes from 36 commits

Commits (53)
efd4d01
Changed MultiCompletionLLMModel to LiteLLMModel
maykcaldas Jan 29, 2025
67b6175
pinned llmclient and ldp versions to the repos
maykcaldas Jan 29, 2025
645cbdd
Updated run_prompt to call
maykcaldas Jan 29, 2025
e2b3a20
Fixed tool_choice injection in test_sucessful_memory_agent
maykcaldas Jan 29, 2025
fb45950
Removed default user role in single Messages lists
maykcaldas Jan 29, 2025
2c75f6f
Addressed suggestions
maykcaldas Jan 29, 2025
ca33a08
cleaned some kwargs in string formatting
maykcaldas Jan 29, 2025
594776b
Renamed model to name in settings
maykcaldas Jan 30, 2025
a941564
Removed gather_with_concurrency from pqa
maykcaldas Jan 31, 2025
2eac706
Merge branch 'main' into update-llmclient
maykcaldas Feb 10, 2025
5e8c44a
updated uv.lock
maykcaldas Feb 10, 2025
af2b92b
[pre-commit.ci lite] apply automatic fixes
pre-commit-ci-lite[bot] Feb 10, 2025
ea83591
Merge branch 'main' into update-llmclient
maykcaldas Feb 10, 2025
bca04bf
Removed deprecation comment and simplified test prompt
maykcaldas Feb 11, 2025
7ab10dc
Removed legacy comment
maykcaldas Feb 11, 2025
6111325
Fixed typos
maykcaldas Feb 11, 2025
d1d4c38
Avoided mocking rate_limited
maykcaldas Feb 11, 2025
211a5c1
Merge branch 'main' into update-llmclient
maykcaldas Feb 11, 2025
d87e3f9
Resolved ARG002
maykcaldas Feb 11, 2025
bfe8ec1
Merge branch 'main' into update-llmclient
maykcaldas Feb 11, 2025
a546bb3
Added 'noqa' to avoid deleting unneeded args
maykcaldas Feb 11, 2025
98de4aa
Passed messages list directly to keyword argument
maykcaldas Feb 11, 2025
a8ed485
renamed variable to prompt_template to make it more explicit
maykcaldas Feb 11, 2025
58875ac
Formatting code
maykcaldas Feb 11, 2025
9d79094
Renamed prompt_details to prompt_templates
maykcaldas Feb 11, 2025
c4336e9
Fixed wrong argument calling in max_fxn_summary
maykcaldas Feb 11, 2025
76ae1c4
Merge branch 'main' into update-llmclient
maykcaldas Feb 11, 2025
c5cdbf1
Allowed override using extra_prompt_data in map_fxn_summary
maykcaldas Feb 12, 2025
1659a9e
Improved map_fxn_summary documentation and removed old comment
maykcaldas Feb 13, 2025
8b5fd50
Adds deepseek extra
maykcaldas Feb 13, 2025
a53ff2f
[pre-commit.ci lite] apply automatic fixes
pre-commit-ci-lite[bot] Feb 13, 2025
31865d7
Added deepseek extra to pre-commit-config
maykcaldas Feb 13, 2025
cc5ab9a
Merge branch 'main' into update-llmclient
maykcaldas Feb 13, 2025
426b0df
migrate recently merged PR
maykcaldas Feb 13, 2025
7bb5fe2
Removed .format from non-template string
maykcaldas Feb 13, 2025
5769fdc
Fixed misatributed prompts
maykcaldas Feb 13, 2025
3c6cac0
Merge branch 'main' into update-llmclient
maykcaldas Feb 19, 2025
df001c4
Updated fhaviary version
maykcaldas Feb 19, 2025
bdb0a15
Removed deprecated task.py file
maykcaldas Feb 19, 2025
da9b7c9
[pre-commit.ci lite] apply automatic fixes
pre-commit-ci-lite[bot] Feb 19, 2025
244bbf0
ran prettier
maykcaldas Feb 19, 2025
9eacdc5
Merge branch 'update-llmclient' of github.com:Future-House/paper-qa i…
maykcaldas Feb 19, 2025
c449cb4
Merge branch 'main' into update-llmclient
maykcaldas Feb 19, 2025
89d75f0
fixed typo in mailmap and avoided unneeded dict in formatting string
maykcaldas Feb 20, 2025
d87d932
Merge branch 'update-llmclient' of github.com:Future-House/paper-qa i…
maykcaldas Feb 20, 2025
bc4b3bf
Merge branch 'main' into update-llmclient
maykcaldas Feb 20, 2025
b3b0455
Bumped llmclient and ldp versions
maykcaldas Feb 20, 2025
6273c01
Merge branch 'update-llmclient' of github.com:Future-House/paper-qa i…
maykcaldas Feb 20, 2025
1a15993
Merge branch 'main' into update-llmclient
maykcaldas Feb 20, 2025
5835af7
Merge branch 'main' into update-llmclient
maykcaldas Feb 21, 2025
7a84534
Updated pqa to import from lmi
maykcaldas Feb 21, 2025
51f9812
Merge branch 'update-llmclient' of github.com:Future-House/paper-qa i…
maykcaldas Feb 21, 2025
e30231a
Updated import to lmi
maykcaldas Feb 21, 2025
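
The commits above repeatedly apply one migration: paper-qa's calls to LLMModel.run_prompt(prompt=..., data=..., system_prompt=...) are replaced by explicit aviary Message lists passed to call_single(), and result.text is cast to str because it may be None. A minimal before/after sketch of that pattern, assuming an OpenAI key is configured; the model id, template, and data below are placeholders, not values from this PR:

import asyncio

from aviary.core import Message
from llmclient import LiteLLMModel


async def main() -> None:
    llm_model = LiteLLMModel(name="gpt-4o")  # placeholder model id

    prompt_template = "{question}\n\n{context}"  # hypothetical template
    data = {"question": "What is paper-qa?", "context": "..."}

    # Old style (removed in this PR):
    #   result = await llm_model.run_prompt(
    #       prompt=prompt_template, data=data, system_prompt="Answer concisely."
    #   )

    # New style, as used throughout the diffs below:
    messages = [
        Message(role="system", content="Answer concisely."),
        Message(role="user", content=prompt_template.format(**data)),
    ]
    result = await llm_model.call_single(messages=messages, name="example")
    print(result.text)  # may be None, hence the cast(str, ...) calls in the diffs


if __name__ == "__main__":
    asyncio.run(main())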
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -84,7 +84,7 @@ repos:
- aiohttp>=3.10.6 # Match pyproject.toml
- PyMuPDF>=1.24.12
- anyio
- fh-llm-client[deepseek]>=0.0.11 # Match pyproject.toml
- fh-llm-client[deepseek] @ git+https://github.com/Future-House/llm-client.git@main
- fhaviary[llm]>=0.14 # Match pyproject.toml
- ldp>=0.20 # Match pyproject.toml
- html2text
12 changes: 7 additions & 5 deletions paperqa/agents/helpers.py
@@ -5,6 +5,7 @@
from datetime import datetime
from typing import cast

from aviary.core import Message
from llmclient import LiteLLMModel, LLMModel
from rich.table import Table

@@ -60,12 +61,13 @@ async def litellm_get_search_query(
)
else:
model = llm
result = await model.run_prompt(
prompt=search_prompt,
data={"question": question, "count": count},
system_prompt=None,
messages = [
Message(content=search_prompt.format(question=question, count=count)),
]
result = await model.call_single(
messages=messages,
)
search_query = result.text
search_query = cast(str, result.text)
queries = [s for s in search_query.split("\n") if len(s) > 3] # noqa: PLR2004
# remove "2.", "3.", etc. -- https://regex101.com/r/W2f7F1/1
queries = [re.sub(r"^\d+\.\s*", "", q) for q in queries]
21 changes: 15 additions & 6 deletions paperqa/agents/models.py
@@ -5,9 +5,10 @@
import time
from contextlib import asynccontextmanager
from enum import StrEnum
from typing import Any, ClassVar, Protocol
from typing import Any, ClassVar, Protocol, cast
from uuid import UUID, uuid4

from aviary.core import Message
from llmclient import LiteLLMModel, LLMModel
from pydantic import (
BaseModel,
@@ -79,12 +80,20 @@ async def get_summary(self, llm_model: LLMModel | str = "gpt-4o") -> str:
model = (
LiteLLMModel(name=llm_model) if isinstance(llm_model, str) else llm_model
)
result = await model.run_prompt(
prompt="{question}\n\n{answer}",
data={"question": self.session.question, "answer": self.session.answer},
system_prompt=sys_prompt,
prompt_template = "{question}\n\n{answer}"
messages = [
Message(role="system", content=sys_prompt),
Message(
role="user",
content=prompt_template.format(
question=self.session.question, answer=self.session.answer
),
),
]
result = await model.call_single(
messages=messages,
)
return result.text.strip()
return cast(str, result.text).strip()


class TimerData(BaseModel):
16 changes: 9 additions & 7 deletions paperqa/agents/task.py
@@ -27,6 +27,7 @@
TASK_DATASET_REGISTRY,
Environment,
Frame,
Message,
Messages,
TaskDataset,
ToolRequestMessage,
@@ -549,13 +550,14 @@ async def _evaluate_answer(self) -> dict:
"answer2": self.human_answer if pqa_answer_index == 1 else pqa_answer,
}

result = await pairwise_eval_llm.run_prompt(
prompt=lfrqa_prompt_template,
data=data,
system_prompt=lfrqa_system_prompt,
result = await pairwise_eval_llm.call_single(
messages=[
Message(role="system", content=lfrqa_system_prompt),
Message(role="user", content=lfrqa_prompt_template.format(**data)),
]
)

best_answer_index = self.extract_best_answer_index(result.text)
result_text = cast(str, result.text)
best_answer_index = self.extract_best_answer_index(result_text)
if best_answer_index == pqa_answer_index:
winner, reward = "paperqa", self._rewards["win"]
elif best_answer_index != 0:
@@ -574,7 +576,7 @@ async def _evaluate_answer(self) -> dict:
"paper_search_ids": paper_search_ids,
"gt_doc_ids": self.gt_doc_ids,
"pqa_answer_was_answer_1": pqa_answer_index == 1,
"complete_evaluator_response": result.text,
"complete_evaluator_response": result_text,
"reward": reward,
}

2 changes: 1 addition & 1 deletion paperqa/clients/__init__.py
@@ -6,10 +6,10 @@
from typing import Any, cast

import aiohttp
from llmclient.utils import gather_with_concurrency
from pydantic import BaseModel, ConfigDict

from paperqa.types import Doc, DocDetails
from paperqa.utils import gather_with_concurrency

from .client_models import MetadataPostProcessor, MetadataProvider
from .crossref import CrossrefProvider
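
With this change, gather_with_concurrency is imported from llmclient.utils instead of paperqa.utils (the local copy was removed in commit a941564). A rough usage sketch, assuming the helper keeps the original (limit, coroutines) signature of a semaphore-bounded asyncio.gather; the fetch function is illustrative only:

import asyncio

from llmclient.utils import gather_with_concurrency


async def fetch(i: int) -> int:
    await asyncio.sleep(0.1)  # stand-in for a metadata provider request
    return i * i


async def main() -> None:
    # At most 4 coroutines run concurrently; results come back in order.
    results = await gather_with_concurrency(4, [fetch(i) for i in range(10)])
    print(results)


if __name__ == "__main__":
    asyncio.run(main())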
44 changes: 28 additions & 16 deletions paperqa/core.py
@@ -3,9 +3,11 @@
import json
import re
from collections.abc import Callable, Sequence
from typing import Any
from typing import Any, cast

from aviary.core import Message
from llmclient import LLMModel

from paperqa.llms import PromptRunner
from paperqa.types import Context, LLMResult, Text
from paperqa.utils import extract_score, strip_citations

@@ -42,7 +44,8 @@ def escape_double_backslashes(match: re.Match) -> str:
async def map_fxn_summary(
text: Text,
question: str,
prompt_runner: PromptRunner | None,
summary_llm_model: LLMModel | None,
prompt_templates: tuple[str, str] | None,
extra_prompt_data: dict[str, str] | None = None,
parser: Callable[[str], dict[str, Any]] | None = None,
callbacks: Sequence[Callable[[str], None]] | None = None,
@@ -55,12 +58,14 @@ async def map_fxn_summary(

Args:
text: The text to parse.
question: The question to use for the chain.
prompt_runner: The prompt runner to call - should have question, citation,
summary_length, and text fields.
extra_prompt_data: Optional extra kwargs to pass to the prompt runner's data.
parser: The parser to use for parsing - return empty dict on Failure to fallback to text parsing.
callbacks: LLM callbacks to execute in the prompt runner.
question: The question to use for summarization.
summary_llm_model: The LLM model to use for generating summaries.
prompt_templates: Optional two-element tuple containing templates for the user and system prompts.
prompt_templates = (user_prompt_template, system_prompt_template)
extra_prompt_data: Optional extra data to pass to the prompt template.
parser: Optional parser function to parse LLM output into structured data.
Should return dict with at least 'summary' field.
callbacks: Optional sequence of callback functions to execute during LLM calls.

Returns:
The context object and LLMResult to get info about the LLM execution.
@@ -71,14 +76,21 @@
citation = text.name + ": " + text.doc.formatted_citation
success = False

if prompt_runner:
llm_result = await prompt_runner(
{"question": question, "citation": citation, "text": text.text}
| (extra_prompt_data or {}),
callbacks,
"evidence:" + text.name,
if summary_llm_model and prompt_templates:
data = {"question": question, "citation": citation, "text": text.text} | (
extra_prompt_data or {}
)
message_prompt, system_prompt = prompt_templates
messages = [
Message(role="system", content=system_prompt.format(**data)),
Message(role="user", content=message_prompt.format(**data)),
]
llm_result = await summary_llm_model.call_single(
messages=messages,
callbacks=callbacks,
name="evidence:" + text.name,
)
context = llm_result.text
context = cast(str, llm_result.text)
result_data = parser(context) if parser else {}
success = bool(result_data)
if success:
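
map_fxn_summary now takes the summary model and a (user, system) template pair instead of a pre-bound PromptRunner. A hedged usage sketch of the new signature; the model id, templates, and summary length below are placeholders rather than paper-qa's shipped prompts:

from llmclient import LiteLLMModel

from paperqa.core import map_fxn_summary
from paperqa.types import Text


async def summarize_one(text: Text, question: str) -> str:
    summary_llm_model = LiteLLMModel(name="gpt-4o-mini")  # placeholder model id
    prompt_templates = (
        # user template: needs question, citation, text, plus any extra_prompt_data keys
        "Question: {question}\nCitation: {citation}\nText: {text}\n"
        "Summarize the relevant evidence in {summary_length}.",
        # system template
        "You are a careful scientific summarizer.",
    )
    context, llm_result = await map_fxn_summary(
        text=text,
        question=question,
        summary_llm_model=summary_llm_model,
        prompt_templates=prompt_templates,
        extra_prompt_data={"summary_length": "about 100 words"},
        parser=None,  # no JSON parsing; fall back to plain text
        callbacks=None,
    )
    return context.context  # the summarized evidence string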
98 changes: 58 additions & 40 deletions paperqa/docs.py
@@ -8,19 +8,20 @@
import urllib.request
from collections.abc import Callable, Sequence
from datetime import datetime
from functools import partial
from io import BytesIO
from pathlib import Path
from typing import Any, BinaryIO, cast
from uuid import UUID, uuid4

from aviary.core import Message
from llmclient import (
Embeddable,
EmbeddingModel,
LLMModel,
LLMResult,
)
from llmclient.types import set_llm_session_ids
from llmclient.utils import gather_with_concurrency
from pydantic import (
BaseModel,
ConfigDict,
@@ -33,7 +34,6 @@
from paperqa.core import llm_parse_json, map_fxn_summary
from paperqa.llms import (
NumpyVectorStore,
PromptRunner,
VectorStore,
)
from paperqa.paths import PAPERQA_DIR
@@ -43,7 +43,6 @@
from paperqa.types import Doc, DocDetails, DocKey, PQASession, Text
from paperqa.utils import (
citation_to_docname,
gather_with_concurrency,
get_loop,
maybe_is_html,
maybe_is_pdf,
@@ -294,12 +293,14 @@ async def aadd( # noqa: PLR0912
)
if not texts:
raise ValueError(f"Could not read document {path}. Is it empty?")
result = await llm_model.run_prompt(
prompt=parse_config.citation_prompt,
data={"text": texts[0].text},
system_prompt=None, # skip system because it's too hesitant to answer
result = await llm_model.call_single(
messages=[
Message(
content=parse_config.citation_prompt.format(text=texts[0].text)
),
],
)
citation = result.text
citation = cast(str, result.text)
if (
len(citation) < 3 # noqa: PLR2004
or "Unknown" in citation
@@ -315,18 +316,23 @@
# try to extract DOI / title from the citation
if (doi is title is None) and parse_config.use_doc_details:
# TODO: specify a JSON schema here when many LLM providers support this
result = await llm_model.run_prompt(
prompt=parse_config.structured_citation_prompt,
data={"citation": citation},
system_prompt=None,
messages = [
Message(
content=parse_config.structured_citation_prompt.format(
citation=citation
),
),
]
result = await llm_model.call_single(
messages=messages,
)
# This code below tries to isolate the JSON
# based on observed messages from LLMs
# it does so by isolating the content between
# the first { and last } in the response.
# Since the anticipated structure should not be nested,
# we don't have to worry about nested curlies.
clean_text = result.text.split("{", 1)[-1].split("}", 1)[0]
clean_text = cast(str, result.text).split("{", 1)[-1].split("}", 1)[0]
clean_text = "{" + clean_text + "}"
try:
citation_json = json.loads(clean_text)
@@ -609,19 +615,17 @@ async def aget_evidence(
else matches
)

prompt_runner: PromptRunner | None = None
prompt_templates = None
if not answer_config.evidence_skip_summary:
if prompt_config.use_json:
prompt_runner = partial(
summary_llm_model.run_prompt,
prompt_templates = (
prompt_config.summary_json,
system_prompt=prompt_config.summary_json_system,
prompt_config.summary_json_system,
)
else:
prompt_runner = partial(
summary_llm_model.run_prompt,
prompt_templates = (
prompt_config.summary,
system_prompt=prompt_config.system,
prompt_config.system,
)

with set_llm_session_ids(session.id):
@@ -631,7 +635,8 @@
map_fxn_summary(
text=m,
question=session.question,
prompt_runner=prompt_runner,
summary_llm_model=summary_llm_model,
prompt_templates=prompt_templates,
extra_prompt_data={
"summary_length": answer_config.evidence_summary_length,
"citation": f"{m.name}: {m.doc.formatted_citation}",
@@ -712,12 +717,17 @@ async def aquery( # noqa: PLR0912
pre_str = None
if prompt_config.pre is not None:
with set_llm_session_ids(session.id):
pre = await llm_model.run_prompt(
prompt=prompt_config.pre,
data={"question": session.question},
messages = [
Message(role="system", content=prompt_config.system),
Message(
role="user",
content=prompt_config.pre.format(question=session.question),
),
]
pre = await llm_model.call_single(
messages=messages,
callbacks=callbacks,
name="pre",
system_prompt=prompt_config.system,
)
session.add_tokens(pre)
pre_str = pre.text
@@ -766,19 +776,22 @@
)
else:
with set_llm_session_ids(session.id):
answer_result = await llm_model.run_prompt(
prompt=prompt_config.qa,
data={
"context": context_str,
"answer_length": answer_config.answer_length,
"question": session.question,
"example_citation": prompt_config.EXAMPLE_CITATION,
},
data = {
"context": context_str,
"answer_length": answer_config.answer_length,
"question": session.question,
"example_citation": prompt_config.EXAMPLE_CITATION,
}
messages = [
Message(role="system", content=prompt_config.system),
Message(role="user", content=prompt_config.qa.format(**data)),
]
answer_result = await llm_model.call_single(
messages=messages,
callbacks=callbacks,
name="answer",
system_prompt=prompt_config.system,
)
answer_text = answer_result.text
answer_text = cast(str, answer_result.text)
session.add_tokens(answer_result)
# it still happens
if (ex_citation := prompt_config.EXAMPLE_CITATION) in answer_text:
@@ -806,14 +819,19 @@ async def aquery( # noqa: PLR0912

if prompt_config.post is not None:
with set_llm_session_ids(session.id):
post = await llm_model.run_prompt(
prompt=prompt_config.post,
data=session.model_dump(),
messages = [
Message(role="system", content=prompt_config.system),
Message(
role="user",
content=prompt_config.post.format(question=session.question),
),
]
post = await llm_model.call_single(
messages=messages,
callbacks=callbacks,
name="post",
system_prompt=prompt_config.system,
)
answer_text = post.text
answer_text = cast(str, post.text)
session.add_tokens(post)
formatted_answer = f"Question: {session.question}\n\n{post}\n"
if bib: