Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Resolver] API Retry on guess success #5187

Merged
merged 9 commits into from
Nov 30, 2024
49 changes: 18 additions & 31 deletions openhands/resolver/issue_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@
from typing import Any, ClassVar

import jinja2
import litellm
import requests

from openhands.core.config import LLMConfig
from openhands.core.logger import openhands_logger as logger
from openhands.events.event import Event
from openhands.llm.llm import LLM
from openhands.resolver.github_issue import GithubIssue, ReviewThread


class IssueHandlerInterface(ABC):
issue_type: ClassVar[str]
llm: LLM

@abstractmethod
def get_converted_issues(
Expand All @@ -36,7 +37,7 @@ def get_instruction(

@abstractmethod
def guess_success(
self, issue: GithubIssue, history: list[Event], llm_config: LLMConfig
self, issue: GithubIssue, history: list[Event]
) -> tuple[bool, list[bool] | None, str]:
"""Guess if the issue has been resolved based on the agent's output."""
pass
Expand All @@ -45,11 +46,12 @@ def guess_success(
class IssueHandler(IssueHandlerInterface):
issue_type: ClassVar[str] = 'issue'

def __init__(self, owner: str, repo: str, token: str):
def __init__(self, owner: str, repo: str, token: str, llm_config: LLMConfig):
self.download_url = 'https://api.github.com/repos/{}/{}/issues'
self.owner = owner
self.repo = repo
self.token = token
self.llm = LLM(llm_config)

def _download_issues_from_github(self) -> list[Any]:
url = self.download_url.format(self.owner, self.repo)
Expand Down Expand Up @@ -218,7 +220,7 @@ def get_instruction(
)

def guess_success(
self, issue: GithubIssue, history: list[Event], llm_config: LLMConfig
self, issue: GithubIssue, history: list[Event]
) -> tuple[bool, None | list[bool], str]:
"""Guess if the issue is fixed based on the history and the issue description."""
last_message = history[-1].message
Expand All @@ -239,12 +241,7 @@ def guess_success(
template = jinja2.Template(f.read())
prompt = template.render(issue_context=issue_context, last_message=last_message)

response = litellm.completion(
model=llm_config.model,
messages=[{'role': 'user', 'content': prompt}],
api_key=llm_config.api_key,
base_url=llm_config.base_url,
)
response = self.llm.completion(messages=[{'role': 'user', 'content': prompt}])

answer = response.choices[0].message.content.strip()
pattern = r'--- success\n*(true|false)\n*--- explanation*\n((?:.|\n)*)'
Expand All @@ -258,8 +255,8 @@ def guess_success(
class PRHandler(IssueHandler):
issue_type: ClassVar[str] = 'pr'

def __init__(self, owner: str, repo: str, token: str):
super().__init__(owner, repo, token)
def __init__(self, owner: str, repo: str, token: str, llm_config: LLMConfig):
super().__init__(owner, repo, token, llm_config)
self.download_url = 'https://api.github.com/repos/{}/{}/pulls'

def __download_pr_metadata(
Expand Down Expand Up @@ -612,16 +609,9 @@ def get_instruction(
)
return instruction, images

def _check_feedback_with_llm(
self, prompt: str, llm_config: LLMConfig
) -> tuple[bool, str]:
def _check_feedback_with_llm(self, prompt: str) -> tuple[bool, str]:
"""Helper function to check feedback with LLM and parse response."""
response = litellm.completion(
model=llm_config.model,
messages=[{'role': 'user', 'content': prompt}],
api_key=llm_config.api_key,
base_url=llm_config.base_url,
)
response = self.llm.completion(messages=[{'role': 'user', 'content': prompt}])

answer = response.choices[0].message.content.strip()
pattern = r'--- success\n*(true|false)\n*--- explanation*\n((?:.|\n)*)'
Expand All @@ -635,7 +625,6 @@ def _check_review_thread(
review_thread: ReviewThread,
issues_context: str,
last_message: str,
llm_config: LLMConfig,
) -> tuple[bool, str]:
"""Check if a review thread's feedback has been addressed."""
files_context = json.dumps(review_thread.files, indent=4)
Expand All @@ -656,14 +645,13 @@ def _check_review_thread(
last_message=last_message,
)

return self._check_feedback_with_llm(prompt, llm_config)
return self._check_feedback_with_llm(prompt)

def _check_thread_comments(
self,
thread_comments: list[str],
issues_context: str,
last_message: str,
llm_config: LLMConfig,
) -> tuple[bool, str]:
"""Check if thread comments feedback has been addressed."""
thread_context = '\n---\n'.join(thread_comments)
Expand All @@ -682,14 +670,13 @@ def _check_thread_comments(
last_message=last_message,
)

return self._check_feedback_with_llm(prompt, llm_config)
return self._check_feedback_with_llm(prompt)

def _check_review_comments(
self,
review_comments: list[str],
issues_context: str,
last_message: str,
llm_config: LLMConfig,
) -> tuple[bool, str]:
"""Check if review comments feedback has been addressed."""
review_context = '\n---\n'.join(review_comments)
Expand All @@ -708,10 +695,10 @@ def _check_review_comments(
last_message=last_message,
)

return self._check_feedback_with_llm(prompt, llm_config)
return self._check_feedback_with_llm(prompt)

def guess_success(
self, issue: GithubIssue, history: list[Event], llm_config: LLMConfig
self, issue: GithubIssue, history: list[Event]
) -> tuple[bool, None | list[bool], str]:
"""Guess if the issue is fixed based on the history and the issue description."""
last_message = history[-1].message
Expand All @@ -724,7 +711,7 @@ def guess_success(
for review_thread in issue.review_threads:
if issues_context and last_message:
success, explanation = self._check_review_thread(
review_thread, issues_context, last_message, llm_config
review_thread, issues_context, last_message
)
else:
success, explanation = False, 'Missing context or message'
Expand All @@ -734,7 +721,7 @@ def guess_success(
elif issue.thread_comments:
if issue.thread_comments and issues_context and last_message:
success, explanation = self._check_thread_comments(
issue.thread_comments, issues_context, last_message, llm_config
issue.thread_comments, issues_context, last_message
)
else:
success, explanation = (
Expand All @@ -747,7 +734,7 @@ def guess_success(
# Handle PRs with only review comments (no file-specific review comments or thread comments)
if issue.review_comments and issues_context and last_message:
success, explanation = self._check_review_comments(
issue.review_comments, issues_context, last_message, llm_config
issue.review_comments, issues_context, last_message
)
else:
success, explanation = (
Expand Down
2 changes: 1 addition & 1 deletion openhands/resolver/resolve_all_issues.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ async def resolve_issues(
repo_instruction: Repository instruction to use.
issue_numbers: List of issue numbers to resolve.
"""
issue_handler = issue_handler_factory(issue_type, owner, repo, token)
issue_handler = issue_handler_factory(issue_type, owner, repo, token, llm_config)

# Load dataset
issues: list[GithubIssue] = issue_handler.get_converted_issues(
Expand Down
10 changes: 5 additions & 5 deletions openhands/resolver/resolve_issue.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ async def on_event(evt):
metrics = state.metrics.get() if state.metrics else None
# determine success based on the history and the issue description
success, comment_success, success_explanation = issue_handler.guess_success(
issue, state.history, llm_config
issue, state.history
)

if issue_handler.issue_type == 'pr' and comment_success:
Expand Down Expand Up @@ -291,12 +291,12 @@ async def on_event(evt):


def issue_handler_factory(
issue_type: str, owner: str, repo: str, token: str
issue_type: str, owner: str, repo: str, token: str, llm_config: LLMConfig
) -> IssueHandlerInterface:
if issue_type == 'issue':
return IssueHandler(owner, repo, token)
return IssueHandler(owner, repo, token, llm_config)
elif issue_type == 'pr':
return PRHandler(owner, repo, token)
return PRHandler(owner, repo, token, llm_config)
else:
raise ValueError(f'Invalid issue type: {issue_type}')

Expand Down Expand Up @@ -337,7 +337,7 @@ async def resolve_issue(
target_branch: Optional target branch to create PR against (for PRs).
reset_logger: Whether to reset the logger for multiprocessing.
"""
issue_handler = issue_handler_factory(issue_type, owner, repo, token)
issue_handler = issue_handler_factory(issue_type, owner, repo, token, llm_config)

# Load dataset
issues: list[GithubIssue] = issue_handler.get_converted_issues(
Expand Down
Loading
Loading