Empty tool calls leading to done condition (#671)
jamesbraza authored Nov 9, 2024
1 parent 1af325c commit db1cc76
Showing 5 changed files with 89 additions and 54 deletions.
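In short: `GradablePaperQAEnvironment.step` previously assumed every action carried at least one tool call, so an agent that emitted none fell through to `exec_tool_calls` with nothing to run. After this commit, an empty `ToolRequestMessage` is a deliberate termination signal: the environment returns one explanatory observation and `done=True`. The return annotations also move to aviary's `Messages` alias. A minimal sketch of the new behavior, mirroring the test added below (the `QueryRequest` import path and default `Settings()` here are assumptions for illustration, not taken from this diff):

import asyncio

from aviary.core import ToolRequestMessage

from paperqa import Docs, Settings
from paperqa.agents.env import GradablePaperQAEnvironment
from paperqa.agents.models import QueryRequest  # assumed import path


async def main() -> None:
    env = GradablePaperQAEnvironment(
        query=QueryRequest(
            query="How can you use XAI for chemical property prediction?",
            settings=Settings(),  # assumed; the tests use a fixture's settings
        ),
        docs=Docs(),
    )
    await env.reset()
    # An action with zero tool calls now ends the episode cleanly.
    obs, _, done, truncated = await env.step(ToolRequestMessage())
    assert done and not truncated
    print(obs[0].content)  # "Agent specified 0 tool calls, which means done."


asyncio.run(main())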
29 changes: 20 additions & 9 deletions paperqa/agents/env.py
@@ -1,11 +1,12 @@
 import logging
 from copy import deepcopy
-from typing import Any, Self, cast
+from typing import Any, ClassVar, Self, cast

 from aviary.core import (
     Environment,
     Frame,
     Message,
+    Messages,
     Tool,
     ToolRequestMessage,
     ToolResponseMessage,
@@ -169,30 +170,40 @@ def _has_excess_answer_failures(self) -> bool:
             > self._query.settings.answer.max_answer_attempts
         )

+    USE_POST_PROCESSED_REWARD: ClassVar[float] = 0.0
+
     async def step(
         self, action: ToolRequestMessage
-    ) -> tuple[list[Message], float, bool, bool]:
+    ) -> tuple[Messages, float, bool, bool]:
         self.state.record_action(action)
+        if not action.tool_calls:
+            return (
+                # NOTE: don't put:
+                # - GenerateAnswer.FAILED_TO_ANSWER here because this wasn't a failure
+                # - 'cannot answer' because that information belongs in
+                #   PQASession.answer, not in the message history
+                # Let's just put a nice message about being done :)
+                [Message(content="Agent specified 0 tool calls, which means done.")],
+                self.USE_POST_PROCESSED_REWARD,
+                True,  # Matching LangChain: https://github.com/langchain-ai/langchain/blob/langchain%3D%3D0.2.17/libs/langchain/langchain/agents/output_parsers/openai_functions.py#L38-L77
+                False,  # Let caller determine truncations
+            )

         response_messages = cast(
             list[Message],
-            await self.exec_tool_calls(
-                action,
-                state=self.state,
-                handle_tool_exc=True,
-            ),
+            await self.exec_tool_calls(action, state=self.state, handle_tool_exc=True),
         )
         return (
             response_messages,
-            0,  # Reward is computed in post-processing, use 0 as a placeholder
+            self.USE_POST_PROCESSED_REWARD,
             any(
                 isinstance(msg, ToolResponseMessage)
                 and msg.name == GenerateAnswer.gen_answer.__name__
                 and GenerateAnswer.did_not_fail_to_answer(msg.content)
                 for msg in response_messages
             )
             or self._has_excess_answer_failures(),
-            False,
+            False,  # Let caller determine truncations
         )

     def __deepcopy__(self, memo) -> Self:
4 changes: 2 additions & 2 deletions paperqa/agents/task.py
@@ -18,7 +18,7 @@
 from aviary.core import (
     TASK_DATASET_REGISTRY,
     Frame,
-    Message,
+    Messages,
     TaskDataset,
     ToolRequestMessage,
     ToolResponseMessage,
@@ -126,7 +126,7 @@ async def validate_sources(

     async def step(
         self, action: ToolRequestMessage
-    ) -> tuple[list[Message], float, bool, bool]:
+    ) -> tuple[Messages, float, bool, bool]:
         messages, reward, done, truncated = await super().step(action)
         if not done or not self._evaluation_from_answer:
             return messages, reward, done, truncated
6 changes: 4 additions & 2 deletions paperqa/agents/tools.py
@@ -268,11 +268,13 @@ class GenerateAnswer(NamedTool):
     summary_llm_model: LiteLLMModel
     embedding_model: EmbeddingModel

+    # This is not an answer to assign to the current PQASession,
+    # but a status for the agent message history
     FAILED_TO_ANSWER: ClassVar[str] = "Failed to answer question."

     @classmethod
-    def did_not_fail_to_answer(cls, message: str) -> bool:
-        return not message.startswith(cls.FAILED_TO_ANSWER)
+    def did_not_fail_to_answer(cls, message: str | None) -> bool:
+        return not (message or "").startswith(cls.FAILED_TO_ANSWER)

     async def gen_answer(self, question: str, state: EnvironmentState) -> str:
         """
6 changes: 4 additions & 2 deletions paperqa/settings.py
@@ -455,8 +455,10 @@ class AgentSettings(BaseModel):
     agent_prompt: str = (
         "Use the tools to answer the question: {question}\n\nThe {gen_answer_tool_name}"
         " tool output is visible to the user, so you do not need to restate the answer"
-        " and can simply terminate if the answer looks sufficient. The current status"
-        " of evidence/papers/cost is {status}"
+        " and can simply terminate if the answer looks sufficient. If the answer does"
+        " not look sufficient, and you have already tried to answer several times, you"
+        " can terminate the question by specifying 0 tool calls."
+        " The current status of evidence/papers/cost is {status}"
     )
     return_paper_metadata: bool = Field(
         default=False,
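Since `agent_prompt` is a plain format string, the rendered text can be previewed with `str.format`; the `status` value below is illustrative, not paperqa's actual status format:

agent_prompt = (
    "Use the tools to answer the question: {question}\n\nThe {gen_answer_tool_name}"
    " tool output is visible to the user, so you do not need to restate the answer"
    " and can simply terminate if the answer looks sufficient. If the answer does"
    " not look sufficient, and you have already tried to answer several times, you"
    " can terminate the question by specifying 0 tool calls."
    " The current status of evidence/papers/cost is {status}"
)

print(
    agent_prompt.format(
        question="How can you use XAI for chemical property prediction?",
        gen_answer_tool_name="gen_answer",  # matches the tool name used in the tests
        status="Evidence: 5 | Papers: 3 | Cost: $0.12",  # illustrative status string
    )
)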
98 changes: 59 additions & 39 deletions tests/test_agents.py
@@ -744,46 +744,66 @@ def test_agent_prompt_collection_validations(
         AgentSettings(**kwargs)


-@pytest.mark.flaky(reruns=2, only_rerun=["AssertionError"])
-@pytest.mark.asyncio
-async def test_deepcopy_env(agent_test_settings: Settings) -> None:
-    await get_directory_index(settings=agent_test_settings)  # Trigger build
+class TestGradablePaperQAEnvironment:
+    @pytest.mark.flaky(reruns=2, only_rerun=["AssertionError"])
+    @pytest.mark.asyncio
+    async def test_deepcopy_env(self, agent_test_settings: Settings) -> None:
+        await get_directory_index(settings=agent_test_settings)  # Trigger build

-    question = "How can you use XAI for chemical property prediction?"
-    env = GradablePaperQAEnvironment(
-        query=QueryRequest(query=question, settings=agent_test_settings),
-        docs=Docs(),
-    )
+        question = "How can you use XAI for chemical property prediction?"
+        env = GradablePaperQAEnvironment(
+            query=QueryRequest(query=question, settings=agent_test_settings),
+            docs=Docs(),
+        )

-    # 1. Rollout until after gather evidence
-    await env.reset()
-    for tool_call in (
-        ToolCall.from_name(
-            "paper_search",
-            query="XAI for chemical property prediction",
-            min_year=2018,
-            max_year=2024,
-        ),
-        ToolCall.from_name("gather_evidence", question=question),
-    ):
-        await env.step(ToolRequestMessage(tool_calls=[tool_call]))
+        # 1. Rollout until after gather evidence
+        await env.reset()
+        for tool_call in (
+            ToolCall.from_name(
+                "paper_search",
+                query="XAI for chemical property prediction",
+                min_year=2018,
+                max_year=2024,
+            ),
+            ToolCall.from_name("gather_evidence", question=question),
+        ):
+            await env.step(ToolRequestMessage(tool_calls=[tool_call]))

-    # 2. Now we deepcopy the environment
-    env_copy = deepcopy(env)
-    assert env.state == env_copy.state
+        # 2. Now we deepcopy the environment
+        env_copy = deepcopy(env)
+        assert env.state == env_copy.state

-    # 3. Generate an answer for both, and confirm they are identical
-    gen_answer_action = ToolRequestMessage(
-        tool_calls=[ToolCall.from_name("gen_answer", question=question)]
-    )
-    _, _, done, _ = await env.step(gen_answer_action)
-    assert done
-    assert not env.state.session.could_not_answer
-    assert env.state.session.used_contexts
-    _, _, done, _ = await env_copy.step(gen_answer_action)
-    assert done
-    assert not env_copy.state.session.could_not_answer
-    assert env_copy.state.session.used_contexts
-    assert sorted(env.state.session.used_contexts) == sorted(
-        env_copy.state.session.used_contexts
-    )
+        # 3. Generate an answer for both, and confirm they are identical
+        gen_answer_action = ToolRequestMessage(
+            tool_calls=[ToolCall.from_name("gen_answer", question=question)]
+        )
+        _, _, done, _ = await env.step(gen_answer_action)
+        assert done
+        assert not env.state.session.could_not_answer
+        assert env.state.session.used_contexts
+        _, _, done, _ = await env_copy.step(gen_answer_action)
+        assert done
+        assert not env_copy.state.session.could_not_answer
+        assert env_copy.state.session.used_contexts
+        assert sorted(env.state.session.used_contexts) == sorted(
+            env_copy.state.session.used_contexts
+        )
+
+    @pytest.mark.asyncio
+    async def test_empty_tool_calls(self, agent_test_settings: Settings) -> None:
+        env = GradablePaperQAEnvironment(
+            query=QueryRequest(
+                query="How can you use XAI for chemical property prediction?",
+                settings=agent_test_settings,
+            ),
+            docs=Docs(),
+        )
+
+        await env.reset()
+        obs, _, done, truncated = await env.step(ToolRequestMessage())
+        assert len(obs) == 1
+        assert obs[0].content
+        assert GenerateAnswer.did_not_fail_to_answer(obs[0].content)
+        assert "0 tool calls" in obs[0].content
+        assert done
+        assert not truncated