Updated run_prompt to call_single
maykcaldas committed Jan 29, 2025
1 parent 67b6175 commit 645cbdd
Showing 8 changed files with 2,789 additions and 2,666 deletions.
13 changes: 8 additions & 5 deletions paperqa/agents/helpers.py
@@ -5,6 +5,7 @@
from datetime import datetime
from typing import cast

from aviary.core import Message
from llmclient import LiteLLMModel, LLMModel
from rich.table import Table

@@ -60,12 +61,14 @@ async def litellm_get_search_query(
)
else:
model = llm
result = await model.run_prompt(
prompt=search_prompt,
data={"question": question, "count": count},
system_prompt=None,
data = {"question": question, "count": count}
messages = [
Message(role="user", content=search_prompt.format(**data)),
]
result = await model.call_single( # run_prompt is deprecated
messages=messages,
)
search_query = result.text
search_query = cast(str, result.text)
queries = [s for s in search_query.split("\n") if len(s) > 3] # noqa: PLR2004
# remove "2.", "3.", etc. -- https://regex101.com/r/W2f7F1/1
queries = [re.sub(r"^\d+\.\s*", "", q) for q in queries]
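Every hunk in this commit applies the same migration: instead of handing run_prompt a template plus a data dict, the caller now formats the template itself and passes explicit Message objects to call_single. A minimal sketch of the pattern, not part of the commit (the helper name run_user_prompt is hypothetical):

from typing import cast

from aviary.core import Message
from llmclient import LLMModel


async def run_user_prompt(model: LLMModel, template: str, **data: str) -> str:
    """Hypothetical helper mirroring the run_prompt -> call_single migration."""
    # Old, deprecated style: model.run_prompt(prompt=template, data=data, system_prompt=None)
    # New style: format the template locally and send an explicit user message.
    messages = [Message(role="user", content=template.format(**data))]
    result = await model.call_single(messages=messages)
    # result.text is typed as optional, hence the casts added throughout this diff.
    return cast(str, result.text)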
17 changes: 11 additions & 6 deletions paperqa/agents/models.py
@@ -5,9 +5,10 @@
import time
from contextlib import asynccontextmanager
from enum import StrEnum
from typing import Any, ClassVar, Protocol
from typing import Any, ClassVar, Protocol, cast
from uuid import UUID, uuid4

from aviary.core import Message
from llmclient import LiteLLMModel, LLMModel
from pydantic import (
BaseModel,
@@ -79,12 +80,16 @@ async def get_summary(self, llm_model: LLMModel | str = "gpt-4o") -> str:
model = (
LiteLLMModel(name=llm_model) if isinstance(llm_model, str) else llm_model
)
result = await model.run_prompt(
prompt="{question}\n\n{answer}",
data={"question": self.session.question, "answer": self.session.answer},
system_prompt=sys_prompt,
data = {"question": self.session.question, "answer": self.session.answer}
prompt = "{question}\n\n{answer}"
messages = [
Message(role="system", content=sys_prompt),
Message(role="user", content=prompt.format(**data)),
]
result = await model.call_single( # run_prompt is deprecated
messages=messages,
)
return result.text.strip()
return cast(str, result.text).strip()


class TimerData(BaseModel):
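When a system prompt is involved, the migration adds an explicit system Message ahead of the user message, as in get_summary above. A sketch under the same assumptions, not part of the commit; the prompt strings and model name are placeholders:

from typing import cast

from aviary.core import Message
from llmclient import LiteLLMModel


async def summarize(question: str, answer: str) -> str:
    """Placeholder prompts; only the message layout mirrors get_summary above."""
    model = LiteLLMModel(name="gpt-4o")
    messages = [
        Message(role="system", content="Summarize the answer in one sentence."),
        Message(role="user", content=f"{question}\n\n{answer}"),
    ]
    result = await model.call_single(messages=messages)  # run_prompt is deprecated
    return cast(str, result.text).strip()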
47 changes: 31 additions & 16 deletions paperqa/core.py
@@ -3,9 +3,11 @@
import json
import re
from collections.abc import Callable, Sequence
from typing import Any
from typing import Any, cast

from aviary.core import Message
from llmclient import LLMModel

from paperqa.llms import PromptRunner
from paperqa.types import Context, LLMResult, Text
from paperqa.utils import extract_score, strip_citations

@@ -38,7 +40,8 @@ def escape_newlines(match: re.Match) -> str:
async def map_fxn_summary(
text: Text,
question: str,
prompt_runner: PromptRunner | None,
summary_llm_model: LLMModel | None,
prompt_details: tuple[str, str] | None,
extra_prompt_data: dict[str, str] | None = None,
parser: Callable[[str], dict[str, Any]] | None = None,
callbacks: Sequence[Callable[[str], None]] | None = None,
@@ -51,12 +54,13 @@
Args:
text: The text to parse.
question: The question to use for the chain.
prompt_runner: The prompt runner to call - should have question, citation,
summary_length, and text fields.
extra_prompt_data: Optional extra kwargs to pass to the prompt runner's data.
parser: The parser to use for parsing - return empty dict on Failure to fallback to text parsing.
callbacks: LLM callbacks to execute in the prompt runner.
question: The question to use for summarization.
summary_llm_model: The LLM model to use for generating summaries.
prompt_details: Tuple containing the prompt template and system prompt.
extra_prompt_data: Optional extra data to pass to the prompt template.
parser: Optional parser function to parse LLM output into structured data.
Should return dict with at least 'summary' field.
callbacks: Optional sequence of callback functions to execute during LLM calls.
Returns:
The context object and LLMResult to get info about the LLM execution.
@@ -67,14 +71,25 @@
citation = text.name + ": " + text.doc.formatted_citation
success = False

if prompt_runner:
llm_result = await prompt_runner(
{"question": question, "citation": citation, "text": text.text}
| (extra_prompt_data or {}),
callbacks,
"evidence:" + text.name,
if summary_llm_model and prompt_details:
data = {
"question": question,
"citation": citation,
"text": text.text,
**(extra_prompt_data or {}),
}
message_prompt = prompt_details[0]
system_prompt = prompt_details[1]
messages = [
Message(role="system", content=system_prompt.format(**data)),
Message(role="user", content=message_prompt.format(**data)),
]
llm_result = await summary_llm_model.call_single( # run_prompt is deprecated
messages=messages,
callbacks=callbacks,
name="evidence:" + text.name,
)
context = llm_result.text
context = cast(str, llm_result.text)
result_data = parser(context) if parser else {}
success = bool(result_data)
if success:
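In core.py the prompt_runner callable is replaced by an explicit model plus a (user_template, system_template) pair, with both templates formatted from the same data mapping. A minimal sketch of that consumption path, not part of the commit, assuming call_single accepts the name keyword as shown above:

from typing import cast

from aviary.core import Message
from llmclient import LLMModel


async def run_prompt_details(
    model: LLMModel, prompt_details: tuple[str, str], data: dict[str, str], name: str
) -> str:
    """Hypothetical helper mirroring how map_fxn_summary now builds its messages."""
    message_prompt, system_prompt = prompt_details  # (user template, system template)
    messages = [
        # Both templates are formatted with the same data mapping.
        Message(role="system", content=system_prompt.format(**data)),
        Message(role="user", content=message_prompt.format(**data)),
    ]
    result = await model.call_single(messages=messages, name=name)
    return cast(str, result.text)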
94 changes: 55 additions & 39 deletions paperqa/docs.py
@@ -8,12 +8,12 @@
import urllib.request
from collections.abc import Callable, Sequence
from datetime import datetime
from functools import partial
from io import BytesIO
from pathlib import Path
from typing import Any, BinaryIO, cast
from uuid import UUID, uuid4

from aviary.core import Message
from llmclient import (
Embeddable,
EmbeddingModel,
@@ -33,7 +33,6 @@
from paperqa.core import llm_parse_json, map_fxn_summary
from paperqa.llms import (
NumpyVectorStore,
PromptRunner,
VectorStore,
)
from paperqa.paths import PAPERQA_DIR
@@ -294,12 +293,16 @@ async def aadd( # noqa: PLR0912
)
if not texts:
raise ValueError(f"Could not read document {path}. Is it empty?")
result = await llm_model.run_prompt(
prompt=parse_config.citation_prompt,
data={"text": texts[0].text},
system_prompt=None, # skip system because it's too hesitant to answer
data = {"text": texts[0].text}
messages = [
Message(
role="user", content=parse_config.citation_prompt.format(**data)
),
]
result = await llm_model.call_single( # run_prompt is deprecated
messages=messages,
)
citation = result.text
citation = cast(str, result.text)
if (
len(citation) < 3 # noqa: PLR2004
or "Unknown" in citation
@@ -315,18 +318,23 @@
# try to extract DOI / title from the citation
if (doi is title is None) and parse_config.use_doc_details:
# TODO: specify a JSON schema here when many LLM providers support this
result = await llm_model.run_prompt(
prompt=parse_config.structured_citation_prompt,
data={"citation": citation},
system_prompt=None,
data = {"citation": citation}
messages = [
Message(
role="user",
content=parse_config.structured_citation_prompt.format(**data),
),
]
result = await llm_model.call_single( # run_prompt is deprecated
messages=messages,
)
# This code below tries to isolate the JSON
# based on observed messages from LLMs
# it does so by isolating the content between
# the first { and last } in the response.
# Since the anticipated structure should not be nested,
# we don't have to worry about nested curlies.
clean_text = result.text.split("{", 1)[-1].split("}", 1)[0]
clean_text = cast(str, result.text).split("{", 1)[-1].split("}", 1)[0]
clean_text = "{" + clean_text + "}"
try:
citation_json = json.loads(clean_text)
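The brace isolation above is easy to verify on its own; a standalone sketch with made-up response text, not part of the commit:

import json

# Mimics the isolation above: keep what sits between the first "{" and the
# next "}" -- safe here because the expected JSON is flat, not nested.
raw = 'Sure! Here is the citation data: {"title": "Example", "doi": "10.0/xyz"} Hope that helps.'
clean_text = "{" + raw.split("{", 1)[-1].split("}", 1)[0] + "}"
citation_json = json.loads(clean_text)
print(citation_json["title"])  # -> Example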
@@ -609,19 +617,17 @@ async def aget_evidence(
else matches
)

prompt_runner: PromptRunner | None = None
prompt_details = None
if not answer_config.evidence_skip_summary:
if prompt_config.use_json:
prompt_runner = partial(
summary_llm_model.run_prompt,
prompt_details = (
prompt_config.summary_json,
system_prompt=prompt_config.summary_json_system,
prompt_config.summary_json_system,
)
else:
prompt_runner = partial(
summary_llm_model.run_prompt,
prompt_details = (
prompt_config.summary,
system_prompt=prompt_config.system,
prompt_config.system,
)

with set_llm_session_ids(session.id):
@@ -631,7 +637,8 @@
map_fxn_summary(
text=m,
question=session.question,
prompt_runner=prompt_runner,
summary_llm_model=summary_llm_model,
prompt_details=prompt_details,
extra_prompt_data={
"summary_length": answer_config.evidence_summary_length,
"citation": f"{m.name}: {m.doc.formatted_citation}",
@@ -712,12 +719,15 @@ async def aquery( # noqa: PLR0912
pre_str = None
if prompt_config.pre is not None:
with set_llm_session_ids(session.id):
pre = await llm_model.run_prompt(
prompt=prompt_config.pre,
data={"question": session.question},
data = {"question": session.question}
messages = [
Message(role="system", content=prompt_config.system),
Message(role="user", content=prompt_config.pre.format(**data)),
]
pre = await llm_model.call_single( # run_prompt is deprecated
messages=messages,
callbacks=callbacks,
name="pre",
system_prompt=prompt_config.system,
)
session.add_tokens(pre)
pre_str = pre.text
@@ -766,19 +776,22 @@
)
else:
with set_llm_session_ids(session.id):
answer_result = await llm_model.run_prompt(
prompt=prompt_config.qa,
data={
"context": context_str,
"answer_length": answer_config.answer_length,
"question": session.question,
"example_citation": prompt_config.EXAMPLE_CITATION,
},
data = {
"context": context_str,
"answer_length": answer_config.answer_length,
"question": session.question,
"example_citation": prompt_config.EXAMPLE_CITATION,
}
messages = [
Message(role="system", content=prompt_config.system),
Message(role="user", content=prompt_config.qa.format(**data)),
]
answer_result = await llm_model.call_single( # run_prompt is deprecated
messages=messages,
callbacks=callbacks,
name="answer",
system_prompt=prompt_config.system,
)
answer_text = answer_result.text
answer_text = cast(str, answer_result.text)
session.add_tokens(answer_result)
# it still happens
if (ex_citation := prompt_config.EXAMPLE_CITATION) in answer_text:
@@ -806,14 +819,17 @@

if prompt_config.post is not None:
with set_llm_session_ids(session.id):
post = await llm_model.run_prompt(
prompt=prompt_config.post,
data=session.model_dump(),
data = {"question": session.question}
messages = [
Message(role="system", content=prompt_config.system),
Message(role="user", content=prompt_config.post.format(**data)),
]
post = await llm_model.call_single( # run_prompt is deprecated
messages=messages,
callbacks=callbacks,
name="post",
system_prompt=prompt_config.system,
)
answer_text = post.text
answer_text = cast(str, post.text)
session.add_tokens(post)
formatted_answer = f"Question: {session.question}\n\n{post}\n"
if bib:
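The pre, answer, and post calls in aquery all share one shape: a system message built from prompt_config.system, a formatted user message, and the callbacks/name keywords carried over from run_prompt. A condensed sketch of that shared shape, not part of the commit (the helper name run_stage is hypothetical):

from collections.abc import Callable, Sequence
from typing import cast

from aviary.core import Message
from llmclient import LLMModel


async def run_stage(
    model: LLMModel,
    system_prompt: str,
    user_template: str,
    data: dict[str, str],
    callbacks: Sequence[Callable[[str], None]] | None,
    name: str,  # "pre", "answer", or "post" in the hunks above
) -> str:
    """Hypothetical helper; each call site above also feeds the result to session.add_tokens."""
    messages = [
        Message(role="system", content=system_prompt),
        Message(role="user", content=user_template.format(**data)),
    ]
    result = await model.call_single(  # run_prompt is deprecated
        messages=messages,
        callbacks=callbacks,
        name=name,
    )
    return cast(str, result.text)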
7 changes: 0 additions & 7 deletions paperqa/llms.py
@@ -5,7 +5,6 @@
import uuid
from abc import ABC, abstractmethod
from collections.abc import (
Awaitable,
Callable,
Iterable,
Sequence,
@@ -20,7 +19,6 @@
EmbeddingModes,
HybridEmbeddingModel,
LiteLLMEmbeddingModel,
LLMResult,
SentenceTransformerEmbeddingModel,
SparseEmbeddingModel,
)
@@ -46,11 +44,6 @@
except ImportError:
qdrant_installed = False

PromptRunner = Callable[
[dict, Sequence[Callable[[str], None]] | None, str | None],
Awaitable[LLMResult],
]

logger = logging.getLogger(__name__)

