Skip to content

Commit

Permalink
Fix type checking errors in resolver directory (part 2)
Browse files Browse the repository at this point in the history
  • Loading branch information
openhands-agent committed Feb 18, 2025
1 parent 4d9faf6 commit 305a4a8
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 28 deletions.
11 changes: 6 additions & 5 deletions openhands/resolver/resolve_all_issues.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
import os
import pathlib
import subprocess
from typing import Awaitable, TextIO
from typing import Any, Awaitable, TextIO

from pydantic import SecretStr
from tqdm import tqdm

import openhands
Expand All @@ -25,7 +26,7 @@
)


def cleanup():
def cleanup() -> None:
print('Cleaning up child processes...')
for process in mp.active_children():
print(f'Terminating child process: {process.name}')
Expand Down Expand Up @@ -214,7 +215,7 @@ async def resolve_issues(
# Use asyncio.gather with a semaphore to limit concurrency
sem = asyncio.Semaphore(num_workers)

async def run_with_semaphore(task):
async def run_with_semaphore(task: Awaitable[Any]) -> Any:
async with sem:
return await task

Expand All @@ -228,7 +229,7 @@ async def run_with_semaphore(task):
logger.info('Finished.')


def main():
def main() -> None:
parser = argparse.ArgumentParser(
description='Resolve multiple issues from Github or Gitlab.'
)
Expand Down Expand Up @@ -349,7 +350,7 @@ def main():

llm_config = LLMConfig(
model=my_args.llm_model or os.environ['LLM_MODEL'],
api_key=str(api_key) if api_key else None,
api_key=SecretStr(api_key) if api_key else None,
base_url=my_args.llm_base_url or os.environ.get('LLM_BASE_URL', None),
)

Expand Down
36 changes: 20 additions & 16 deletions openhands/resolver/resolve_issue.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Any
from uuid import uuid4

from pydantic import SecretStr
from termcolor import colored

import openhands
Expand All @@ -18,6 +19,7 @@
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import CmdRunAction, MessageAction
from openhands.events.event import Event
from openhands.events.observation import (
CmdOutputObservation,
ErrorObservation,
Expand Down Expand Up @@ -48,7 +50,7 @@
def initialize_runtime(
runtime: Runtime,
platform: Platform,
):
) -> None:
"""Initialize the runtime for the agent.
This function is called before the runtime is used to run the agent.
Expand Down Expand Up @@ -192,26 +194,28 @@ async def process_issue(
# This code looks unnecessary because these are default values in the config class
# they're set by default if nothing else overrides them
# FIXME we should remove them here
kwargs = {}
sandbox_config = SandboxConfig(
runtime_container_image=runtime_container_image,
enable_auto_lint=False,
use_host_network=False,
# large enough timeout, since some testcases take very long to run
timeout=300,
)

if os.getenv('GITLAB_CI') == 'True':
kwargs['local_runtime_url'] = os.getenv('LOCAL_RUNTIME_URL', 'http://localhost')
sandbox_config.local_runtime_url = os.getenv(
'LOCAL_RUNTIME_URL', 'http://localhost'
)
user_id = os.getuid() if hasattr(os, 'getuid') else 1000
if user_id == 0:
kwargs['user_id'] = get_unique_uid()
sandbox_config.user_id = get_unique_uid()

config = AppConfig(
default_agent='CodeActAgent',
runtime='docker',
max_budget_per_task=4,
max_iterations=max_iterations,
sandbox=SandboxConfig(
runtime_container_image=runtime_container_image,
enable_auto_lint=False,
use_host_network=False,
# large enough timeout, since some testcases take very long to run
timeout=300,
**kwargs,
),
sandbox=sandbox_config,
# do not mount workspace
workspace_base=workspace_base,
workspace_mount_path=workspace_base,
Expand All @@ -222,7 +226,7 @@ async def process_issue(
runtime = create_runtime(config)
await runtime.connect()

def on_event(evt):
def on_event(evt: Event) -> None:
logger.info(evt)

runtime.event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, str(uuid4()))
Expand Down Expand Up @@ -524,10 +528,10 @@ async def resolve_issue(
logger.info('Finished.')


def main():
def main() -> None:
import argparse

def int_or_none(value):
def int_or_none(value: str) -> int | None:
if value.lower() == 'none':
return None
else:
Expand Down Expand Up @@ -654,7 +658,7 @@ def int_or_none(value):
api_key = my_args.llm_api_key or os.environ['LLM_API_KEY']
llm_config = LLMConfig(
model=my_args.llm_model or os.environ['LLM_MODEL'],
api_key=str(api_key) if api_key else None,
api_key=SecretStr(api_key) if api_key else None,
base_url=my_args.llm_base_url or os.environ.get('LLM_BASE_URL', None),
)

Expand Down
72 changes: 70 additions & 2 deletions openhands/resolver/send_pull_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import subprocess

import jinja2
import requests
from pydantic import SecretStr

from openhands.core.config import LLMConfig
from openhands.core.logger import openhands_logger as logger
Expand Down Expand Up @@ -346,6 +348,72 @@ def send_pull_request(
return url


def reply_to_comment(github_token: str, comment_id: str, reply: str) -> None:
"""Reply to a comment on a GitHub issue or pull request.
Args:
github_token: The GitHub token to use for authentication
comment_id: The ID of the comment to reply to
reply: The reply message to post
"""
# Opting for graphql as REST API doesn't allow reply to replies in comment threads
query = """
mutation($body: String!, $pullRequestReviewThreadId: ID!) {
addPullRequestReviewThreadReply(input: { body: $body, pullRequestReviewThreadId: $pullRequestReviewThreadId }) {
comment {
id
body
createdAt
}
}
}
"""

# Prepare the reply to the comment
comment_reply = f'Openhands fix success summary\n\n\n{reply}'
variables = {'body': comment_reply, 'pullRequestReviewThreadId': comment_id}
url = 'https://api.github.com/graphql'
headers = {
'Authorization': f'Bearer {github_token}',
'Content-Type': 'application/json',
}

# Send the reply to the comment
response = requests.post(
url, json={'query': query, 'variables': variables}, headers=headers
)
response.raise_for_status()


def send_comment_msg(
base_url: str, issue_number: int, github_token: str, msg: str
) -> None:
"""Send a comment message to a GitHub issue or pull request.
Args:
base_url: The base URL of the GitHub repository API
issue_number: The issue or pull request number
github_token: The GitHub token to use for authentication
msg: The message content to post as a comment
"""
# Set up headers for GitHub API
headers = {
'Authorization': f'token {github_token}',
'Accept': 'application/vnd.github.v3+json',
}

# Post a comment on the PR
comment_url = f'{base_url}/issues/{issue_number}/comments'
comment_data = {'body': msg}
comment_response = requests.post(comment_url, headers=headers, json=comment_data)
if comment_response.status_code != 201:
print(
f'Failed to post comment: {comment_response.status_code} {comment_response.text}'
)
else:
print(f'Comment added to the PR: {msg}')


def update_existing_pull_request(
issue: Issue,
token: str,
Expand Down Expand Up @@ -543,7 +611,7 @@ def process_all_successful_issues(
)


def main():
def main() -> None:
parser = argparse.ArgumentParser(
description='Send a pull request to Github or Gitlab.'
)
Expand Down Expand Up @@ -641,7 +709,7 @@ def main():
api_key = my_args.llm_api_key or os.environ['LLM_API_KEY']
llm_config = LLMConfig(
model=my_args.llm_model or os.environ['LLM_MODEL'],
api_key=str(api_key) if api_key else None,
api_key=SecretStr(api_key) if api_key else None,
base_url=my_args.llm_base_url or os.environ.get('LLM_BASE_URL', None),
)

Expand Down
10 changes: 6 additions & 4 deletions openhands/resolver/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,17 @@ def codeact_user_response(
return msg


def cleanup():
def cleanup() -> None:
print('Cleaning up child processes...')
for process in mp.active_children():
print(f'Terminating child process: {process.name}')
process.terminate()
process.join()


def prepare_dataset(dataset: pd.DataFrame, output_file: str, eval_n_limit: int):
def prepare_dataset(
dataset: pd.DataFrame, output_file: str, eval_n_limit: int
) -> pd.DataFrame:
assert 'instance_id' in dataset.columns, (
"Expected 'instance_id' column in the dataset. You should define your own "
"unique identifier for each instance and use it as the 'instance_id' column."
Expand Down Expand Up @@ -152,7 +154,7 @@ def prepare_dataset(dataset: pd.DataFrame, output_file: str, eval_n_limit: int):

def reset_logger_for_multiprocessing(
logger: logging.Logger, instance_id: str, log_dir: str
):
) -> None:
"""Reset the logger for multiprocessing.
Save logs to a separate file for each process, instead of trying to write to the
Expand Down Expand Up @@ -208,7 +210,7 @@ def extract_issue_references(body: str) -> list[int]:
return [int(match) for match in re.findall(pattern, body)]


def get_unique_uid(start_uid=1000):
def get_unique_uid(start_uid: int = 1000) -> int:
existing_uids = set()
with open('/etc/passwd', 'r') as passwd_file:
for line in passwd_file:
Expand Down
4 changes: 3 additions & 1 deletion openhands/resolver/visualize_resolver_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from openhands.resolver.io_utils import load_single_resolver_output


def visualize_resolver_output(issue_number: int, output_dir: str, vis_method: str):
def visualize_resolver_output(
issue_number: int, output_dir: str, vis_method: str
) -> None:
output_jsonl = os.path.join(output_dir, 'output.jsonl')
resolver_output = load_single_resolver_output(output_jsonl, issue_number)
if vis_method == 'json':
Expand Down

0 comments on commit 305a4a8

Please sign in to comment.