diff --git a/openhands/resolver/interfaces/github.py b/openhands/resolver/interfaces/github.py index 46cceb68a4f7..d9cd66c53eba 100644 --- a/openhands/resolver/interfaces/github.py +++ b/openhands/resolver/interfaces/github.py @@ -22,28 +22,28 @@ def __init__(self, owner: str, repo: str, token: str, username: str | None = Non self.clone_url = self.get_clone_url() self.headers = self.get_headers() - def set_owner(self, owner: str): + def set_owner(self, owner: str) -> None: self.owner = owner - def get_headers(self): + def get_headers(self) -> dict[str, str]: return { 'Authorization': f'token {self.token}', 'Accept': 'application/vnd.github.v3+json', } - def get_base_url(self): + def get_base_url(self) -> str: return f'https://api.github.com/repos/{self.owner}/{self.repo}' - def get_authorize_url(self): + def get_authorize_url(self) -> str: return f'https://{self.username}:{self.token}@github.com/' - def get_branch_url(self, branch_name: str): + def get_branch_url(self, branch_name: str) -> str: return self.get_base_url() + f'/branches/{branch_name}' - def get_download_url(self): + def get_download_url(self) -> str: return f'{self.base_url}/issues' - def get_clone_url(self): + def get_clone_url(self) -> str: username_and_token = ( f'{self.username}:{self.token}' if self.username @@ -51,10 +51,10 @@ def get_clone_url(self): ) return f'https://{username_and_token}@github.com/{self.owner}/{self.repo}.git' - def get_graphql_url(self): + def get_graphql_url(self) -> str: return 'https://api.github.com/graphql' - def get_compare_url(self, branch_name: str): + def get_compare_url(self, branch_name: str) -> str: return f'https://github.com/{self.owner}/{self.repo}/compare/{branch_name}?expand=1' def get_converted_issues( @@ -186,7 +186,7 @@ def branch_exists(self, branch_name: str) -> bool: print(f'Branch {branch_name} exists: {exists}') return exists - def get_branch_name(self, base_branch_name: str): + def get_branch_name(self, base_branch_name: str) -> str: branch_name = base_branch_name attempt = 1 while self.branch_exists(branch_name): @@ -194,7 +194,7 @@ def get_branch_name(self, base_branch_name: str): branch_name = f'{base_branch_name}-try{attempt}' return branch_name - def reply_to_comment(self, pr_number: int, comment_id: str, reply: str): + def reply_to_comment(self, pr_number: int, comment_id: str, reply: str) -> None: # Opting for graphql as REST API doesn't allow reply to replies in comment threads query = """ mutation($body: String!, $pullRequestReviewThreadId: ID!) { @@ -221,15 +221,18 @@ def reply_to_comment(self, pr_number: int, comment_id: str, reply: str): ) response.raise_for_status() - def get_pull_url(self, pr_number: int): + def get_pull_url(self, pr_number: int) -> str: return f'https://github.com/{self.owner}/{self.repo}/pull/{pr_number}' def get_default_branch_name(self) -> str: response = requests.get(f'{self.base_url}', headers=self.headers) response.raise_for_status() - return response.json()['default_branch'] + data = response.json() + return str(data['default_branch']) - def create_pull_request(self, data=dict) -> dict: + def create_pull_request(self, data: dict[str, Any] | None = None) -> dict[str, Any]: + if data is None: + data = {} response = requests.post( f'{self.base_url}/pulls', headers=self.headers, json=data ) @@ -240,9 +243,9 @@ def create_pull_request(self, data=dict) -> dict: ) response.raise_for_status() pr_data = response.json() - return pr_data + return dict(pr_data) - def request_reviewers(self, reviewer: str, pr_number: int): + def request_reviewers(self, reviewer: str, pr_number: int) -> None: review_data = {'reviewers': [reviewer]} review_response = requests.post( f'{self.base_url}/pulls/{pr_number}/requested_reviewers', @@ -254,7 +257,7 @@ def request_reviewers(self, reviewer: str, pr_number: int): f'Warning: Failed to request review from {reviewer}: {review_response.text}' ) - def send_comment_msg(self, issue_number: int, msg: str): + def send_comment_msg(self, issue_number: int, msg: str) -> None: """Send a comment message to a GitHub issue or pull request. Args: @@ -282,8 +285,8 @@ def get_context_from_external_issues_references( review_comments: list[str] | None, review_threads: list[ReviewThread], thread_comments: list[str] | None, - ): - pass + ) -> list[str]: + return [] class GithubPRHandler(GithubIssueHandler): @@ -487,7 +490,7 @@ def get_context_from_external_issues_references( review_comments: list[str] | None, review_threads: list[ReviewThread], thread_comments: list[str] | None, - ): + ) -> list[str]: new_issue_references = [] if issue_body: diff --git a/openhands/resolver/interfaces/gitlab.py b/openhands/resolver/interfaces/gitlab.py index 52661d93032d..47519c4a4e0f 100644 --- a/openhands/resolver/interfaces/gitlab.py +++ b/openhands/resolver/interfaces/gitlab.py @@ -23,38 +23,38 @@ def __init__(self, owner: str, repo: str, token: str, username: str | None = Non self.clone_url = self.get_clone_url() self.headers = self.get_headers() - def set_owner(self, owner: str): + def set_owner(self, owner: str) -> None: self.owner = owner - def get_headers(self): + def get_headers(self) -> dict[str, str]: return { 'Authorization': f'Bearer {self.token}', 'Accept': 'application/json', } - def get_base_url(self): - project_path = quote(f'{self.owner}/{self.repo}', safe="") + def get_base_url(self) -> str: + project_path = quote(f'{self.owner}/{self.repo}', safe='') return f'https://gitlab.com/api/v4/projects/{project_path}' - def get_authorize_url(self): + def get_authorize_url(self) -> str: return f'https://{self.username}:{self.token}@gitlab.com/' - def get_branch_url(self, branch_name: str): + def get_branch_url(self, branch_name: str) -> str: return self.get_base_url() + f'/repository/branches/{branch_name}' - def get_download_url(self): + def get_download_url(self) -> str: return f'{self.base_url}/issues' - def get_clone_url(self): + def get_clone_url(self) -> str: username_and_token = self.token if self.username: username_and_token = f'{self.username}:{self.token}' return f'https://{username_and_token}@gitlab.com/{self.owner}/{self.repo}.git' - def get_graphql_url(self): + def get_graphql_url(self) -> str: return 'https://gitlab.com/api/graphql' - def get_compare_url(self, branch_name: str): + def get_compare_url(self, branch_name: str) -> str: return f'https://gitlab.com/{self.owner}/{self.repo}/-/compare/{self.get_default_branch_name()}...{branch_name}' def get_converted_issues( @@ -189,7 +189,7 @@ def branch_exists(self, branch_name: str) -> bool: print(f'Branch {branch_name} exists: {exists}') return exists - def get_branch_name(self, base_branch_name: str): + def get_branch_name(self, base_branch_name: str) -> str: branch_name = base_branch_name attempt = 1 while self.branch_exists(branch_name): @@ -197,7 +197,7 @@ def get_branch_name(self, base_branch_name: str): branch_name = f'{base_branch_name}-try{attempt}' return branch_name - def reply_to_comment(self, pr_number: int, comment_id: str, reply: str): + def reply_to_comment(self, pr_number: int, comment_id: str, reply: str) -> None: response = requests.get( f'{self.base_url}/merge_requests/{pr_number}/discussions/{comment_id.split('/')[-1]}', headers=self.headers, @@ -216,7 +216,7 @@ def reply_to_comment(self, pr_number: int, comment_id: str, reply: str): ) response.raise_for_status() - def get_pull_url(self, pr_number: int): + def get_pull_url(self, pr_number: int) -> str: return ( f'https://gitlab.com/{self.owner}/{self.repo}/-/merge_requests/{pr_number}' ) @@ -224,9 +224,12 @@ def get_pull_url(self, pr_number: int): def get_default_branch_name(self) -> str: response = requests.get(f'{self.base_url}', headers=self.headers) response.raise_for_status() - return response.json()['default_branch'] + data = response.json() + return str(data['default_branch']) - def create_pull_request(self, data=dict) -> dict: + def create_pull_request(self, data: dict[str, Any] | None = None) -> dict[str, Any]: + if data is None: + data = {} response = requests.post( f'{self.base_url}/merge_requests', headers=self.headers, json=data ) @@ -243,9 +246,9 @@ def create_pull_request(self, data=dict) -> dict: if 'iid' in pr_data: pr_data['number'] = pr_data['iid'] - return pr_data + return dict(pr_data) - def request_reviewers(self, reviewer: str, pr_number: int): + def request_reviewers(self, reviewer: str, pr_number: int) -> None: response = requests.get( f'https://gitlab.com/api/v4/users?username={reviewer}', headers=self.headers, @@ -264,7 +267,7 @@ def request_reviewers(self, reviewer: str, pr_number: int): f'Warning: Failed to request review from {reviewer}: {review_response.text}' ) - def send_comment_msg(self, issue_number: int, msg: str): + def send_comment_msg(self, issue_number: int, msg: str) -> None: """Send a comment message to a GitHub issue or pull request. Args: @@ -292,8 +295,8 @@ def get_context_from_external_issues_references( review_comments: list[str] | None, review_threads: list[ReviewThread], thread_comments: list[str] | None, - ): - pass + ) -> list[str]: + return [] class GitlabPRHandler(GitlabIssueHandler): @@ -479,7 +482,7 @@ def get_context_from_external_issues_references( review_comments: list[str] | None, review_threads: list[ReviewThread], thread_comments: list[str] | None, - ): + ) -> list[str]: new_issue_references = [] if issue_body: diff --git a/openhands/resolver/interfaces/issue.py b/openhands/resolver/interfaces/issue.py index 263fd8160377..ffd6a204ca7a 100644 --- a/openhands/resolver/interfaces/issue.py +++ b/openhands/resolver/interfaces/issue.py @@ -26,7 +26,7 @@ class Issue(BaseModel): class IssueHandlerInterface(ABC): @abstractmethod - def set_owner(self, owner: str): + def set_owner(self, owner: str) -> None: pass @abstractmethod @@ -40,43 +40,43 @@ def get_issue_comments( pass @abstractmethod - def get_base_url(self): + def get_base_url(self) -> str: pass @abstractmethod - def get_branch_url(self, branch_name): + def get_branch_url(self, branch_name: str) -> str: pass @abstractmethod - def get_download_url(self): + def get_download_url(self) -> str: pass @abstractmethod - def get_clone_url(self): + def get_clone_url(self) -> str: pass @abstractmethod - def get_pull_url(self, pr_number: int): + def get_pull_url(self, pr_number: int) -> str: pass @abstractmethod - def get_graphql_url(self): + def get_graphql_url(self) -> str: pass @abstractmethod - def get_headers(self): + def get_headers(self) -> dict[str, str]: pass @abstractmethod - def get_compare_url(self, branch_name): + def get_compare_url(self, branch_name: str) -> str: pass @abstractmethod - def get_branch_name(self, base_branch_name: str): + def get_branch_name(self, base_branch_name: str) -> str: pass @abstractmethod - def get_default_branch_name(self): + def get_default_branch_name(self) -> str: pass @abstractmethod @@ -84,23 +84,25 @@ def branch_exists(self, branch_name: str) -> bool: pass @abstractmethod - def reply_to_comment(self, pr_number: int, comment_id: str, reply: str): + def reply_to_comment(self, pr_number: int, comment_id: str, reply: str) -> None: pass @abstractmethod - def send_comment_msg(self, issue_number: int, msg: str): + def send_comment_msg(self, issue_number: int, msg: str) -> None: pass @abstractmethod - def get_authorize_url(self): + def get_authorize_url(self) -> str: pass @abstractmethod - def create_pull_request(self, data=dict) -> dict: - pass + def create_pull_request(self, data: dict[str, Any] | None = None) -> dict[str, Any]: + if data is None: + data = {} + raise NotImplementedError @abstractmethod - def request_reviewers(self, reviewer: str, pr_number: int): + def request_reviewers(self, reviewer: str, pr_number: int) -> None: pass @abstractmethod @@ -112,7 +114,7 @@ def get_context_from_external_issues_references( review_comments: list[str] | None, review_threads: list[ReviewThread], thread_comments: list[str] | None, - ): + ) -> list[str]: pass @abstractmethod diff --git a/openhands/resolver/interfaces/issue_definitions.py b/openhands/resolver/interfaces/issue_definitions.py index 6912ab5c1e78..15a9f7cb2d77 100644 --- a/openhands/resolver/interfaces/issue_definitions.py +++ b/openhands/resolver/interfaces/issue_definitions.py @@ -25,7 +25,7 @@ def __init__(self, strategy: IssueHandlerInterface, llm_config: LLMConfig | None if llm_config is not None: self.llm = LLM(llm_config) - def set_strategy(self, strategy): + def set_strategy(self, strategy: IssueHandlerInterface) -> None: self._strategy = strategy @@ -36,7 +36,7 @@ class ServiceContextPR(ServiceContext): def __init__(self, strategy: IssueHandlerInterface, llm_config: LLMConfig): super().__init__(strategy, llm_config) - def get_clone_url(self): + def get_clone_url(self) -> str: return self._strategy.get_clone_url() def download_issues(self) -> list[Any]: @@ -266,31 +266,31 @@ class ServiceContextIssue(ServiceContext): def __init__(self, strategy: IssueHandlerInterface, llm_config: LLMConfig | None): super().__init__(strategy, llm_config) - def get_base_url(self): + def get_base_url(self) -> str: return self._strategy.get_base_url() - def get_branch_url(self, branch_name): + def get_branch_url(self, branch_name: str) -> str: return self._strategy.get_branch_url(branch_name) - def get_download_url(self): + def get_download_url(self) -> str: return self._strategy.get_download_url() - def get_clone_url(self): + def get_clone_url(self) -> str: return self._strategy.get_clone_url() - def get_graphql_url(self): + def get_graphql_url(self) -> str: return self._strategy.get_graphql_url() - def get_headers(self): + def get_headers(self) -> dict[str, str]: return self._strategy.get_headers() - def get_authorize_url(self): + def get_authorize_url(self) -> str: return self._strategy.get_authorize_url() - def get_pull_url(self, pr_number: int): + def get_pull_url(self, pr_number: int) -> str: return self._strategy.get_pull_url(pr_number) - def get_compare_url(self, branch_name: str): + def get_compare_url(self, branch_name: str) -> str: return self._strategy.get_compare_url(branch_name) def download_issues(self) -> list[Any]: @@ -299,25 +299,27 @@ def download_issues(self) -> list[Any]: def get_branch_name( self, base_branch_name: str, - ): + ) -> str: return self._strategy.get_branch_name(base_branch_name) - def branch_exists(self, branch_name: str): + def branch_exists(self, branch_name: str) -> bool: return self._strategy.branch_exists(branch_name) def get_default_branch_name(self) -> str: return self._strategy.get_default_branch_name() - def create_pull_request(self, data=dict): + def create_pull_request(self, data: dict[str, Any] | None = None) -> dict[str, Any]: + if data is None: + data = {} return self._strategy.create_pull_request(data) - def request_reviewers(self, reviewer: str, pr_number: int): + def request_reviewers(self, reviewer: str, pr_number: int) -> None: return self._strategy.request_reviewers(reviewer, pr_number) - def reply_to_comment(self, pr_number, comment_id, reply): + def reply_to_comment(self, pr_number: int, comment_id: str, reply: str) -> None: return self._strategy.reply_to_comment(pr_number, comment_id, reply) - def send_comment_msg(self, issue_number: int, msg: str): + def send_comment_msg(self, issue_number: int, msg: str) -> None: return self._strategy.send_comment_msg(issue_number, msg) def get_issue_comments( diff --git a/openhands/resolver/patching/apply.py b/openhands/resolver/patching/apply.py index aedc521a1cd6..65db16c21a83 100644 --- a/openhands/resolver/patching/apply.py +++ b/openhands/resolver/patching/apply.py @@ -5,10 +5,13 @@ import tempfile from .exceptions import HunkApplyException, SubprocessException +from .patch import Change, diffobj from .snippets import remove, which -def _apply_diff_with_subprocess(diff, lines, reverse=False): +def _apply_diff_with_subprocess( + diff: diffobj, lines: list[str], reverse: bool = False +) -> tuple[list[str], list[str] | None]: # call out to patch program patchexec = which('patch') if not patchexec: @@ -63,21 +66,21 @@ def _apply_diff_with_subprocess(diff, lines, reverse=False): return lines, rejlines -def _reverse(changes): - def _reverse_change(c): +def _reverse(changes: list[Change]) -> list[Change]: + def _reverse_change(c: Change) -> Change: return c._replace(old=c.new, new=c.old) return [_reverse_change(c) for c in changes] -def apply_diff(diff, text, reverse=False, use_patch=False): - try: - lines = text.splitlines() - except AttributeError: - lines = list(text) +def apply_diff( + diff: diffobj, text: str | list[str], reverse: bool = False, use_patch: bool = False +) -> list[str]: + lines = text.splitlines() if isinstance(text, str) else list(text) if use_patch: - return _apply_diff_with_subprocess(diff, lines, reverse) + lines, _ = _apply_diff_with_subprocess(diff, lines, reverse) + return lines n_lines = len(lines) diff --git a/openhands/resolver/patching/exceptions.py b/openhands/resolver/patching/exceptions.py index 30653c56da18..75159ddf6555 100644 --- a/openhands/resolver/patching/exceptions.py +++ b/openhands/resolver/patching/exceptions.py @@ -1,31 +1,31 @@ -class PatchingException(Exception): - pass - - -class HunkException(PatchingException): - def __init__(self, msg, hunk=None): - self.hunk = hunk - if hunk is not None: - super(HunkException, self).__init__( - '{msg}, in hunk #{n}'.format(msg=msg, n=hunk) - ) - else: - super(HunkException, self).__init__(msg) - - -class ApplyException(PatchingException): - pass - - -class SubprocessException(ApplyException): - def __init__(self, msg, code): - super(SubprocessException, self).__init__(msg) - self.code = code - - -class HunkApplyException(HunkException, ApplyException, ValueError): - pass - - -class ParseException(HunkException, ValueError): - pass +class PatchingException(Exception): + pass + + +class HunkException(PatchingException): + def __init__(self, msg: str, hunk: int | None = None) -> None: + self.hunk = hunk + if hunk is not None: + super(HunkException, self).__init__( + '{msg}, in hunk #{n}'.format(msg=msg, n=hunk) + ) + else: + super(HunkException, self).__init__(msg) + + +class ApplyException(PatchingException): + pass + + +class SubprocessException(ApplyException): + def __init__(self, msg: str, code: int) -> None: + super(SubprocessException, self).__init__(msg) + self.code = code + + +class HunkApplyException(HunkException, ApplyException, ValueError): + pass + + +class ParseException(HunkException, ValueError): + pass diff --git a/openhands/resolver/patching/patch.py b/openhands/resolver/patching/patch.py index 97cb5d488293..4359e732512c 100644 --- a/openhands/resolver/patching/patch.py +++ b/openhands/resolver/patching/patch.py @@ -3,6 +3,7 @@ import re import zlib from collections import namedtuple +from typing import Iterable from . import exceptions from .snippets import findall_regex, split_by_regex @@ -71,11 +72,8 @@ old_cvs_diffcmd_header = re.compile('^diff.* (.+):(.*) (.+):(.*)$') -def parse_patch(text): - try: - lines = text.splitlines() - except AttributeError: - lines = text +def parse_patch(text: str | list[str]) -> Iterable[diffobj]: + lines = text.splitlines() if isinstance(text, str) else text # maybe use this to nuke all of those line endings? # lines = [x.splitlines()[0] for x in lines] @@ -104,18 +102,15 @@ def parse_patch(text): yield diffobj(header=h, changes=d, text=difftext) -def parse_header(text): +def parse_header(text: str | list[str]) -> header | None: h = parse_scm_header(text) if h is None: h = parse_diff_header(text) return h -def parse_scm_header(text): - try: - lines = text.splitlines() - except AttributeError: - lines = text +def parse_scm_header(text: str | list[str]) -> header | None: + lines = text.splitlines() if isinstance(text, str) else text check = [ (git_header_index, parse_git_header), @@ -154,11 +149,8 @@ def parse_scm_header(text): return None -def parse_diff_header(text): - try: - lines = text.splitlines() - except AttributeError: - lines = text +def parse_diff_header(text: str | list[str]) -> header | None: + lines = text.splitlines() if isinstance(text, str) else text check = [ (unified_header_new_line, parse_unified_header), @@ -178,10 +170,10 @@ def parse_diff_header(text): return None # no header? -def parse_diff(text): - try: +def parse_diff(text: str | list[str]) -> list[Change] | None: + if isinstance(text, str): lines = text.splitlines() - except AttributeError: + else: lines = text check = [ @@ -200,11 +192,8 @@ def parse_diff(text): return None -def parse_git_header(text): - try: - lines = text.splitlines() - except AttributeError: - lines = text +def parse_git_header(text: str | list[str]) -> header | None: + lines = text.splitlines() if isinstance(text, str) else text old_version = None new_version = None @@ -275,11 +264,8 @@ def parse_git_header(text): return None -def parse_svn_header(text): - try: - lines = text.splitlines() - except AttributeError: - lines = text +def parse_svn_header(text: str | list[str]) -> header | None: + lines = text.splitlines() if isinstance(text, str) else text headers = findall_regex(lines, svn_header_index) if len(headers) == 0: @@ -346,11 +332,8 @@ def parse_svn_header(text): return None -def parse_cvs_header(text): - try: - lines = text.splitlines() - except AttributeError: - lines = text +def parse_cvs_header(text: str | list[str]) -> header | None: + lines = text.splitlines() if isinstance(text, str) else text headers = findall_regex(lines, cvs_header_rcs) headers_old = findall_regex(lines, old_cvs_diffcmd_header) @@ -430,11 +413,8 @@ def parse_cvs_header(text): return None -def parse_diffcmd_header(text): - try: - lines = text.splitlines() - except AttributeError: - lines = text +def parse_diffcmd_header(text: str | list[str]) -> header | None: + lines = text.splitlines() if isinstance(text, str) else text headers = findall_regex(lines, diffcmd_header) if len(headers) == 0: @@ -454,11 +434,8 @@ def parse_diffcmd_header(text): return None -def parse_unified_header(text): - try: - lines = text.splitlines() - except AttributeError: - lines = text +def parse_unified_header(text: str | list[str]) -> header | None: + lines = text.splitlines() if isinstance(text, str) else text headers = findall_regex(lines, unified_header_new_line) if len(headers) == 0: @@ -490,11 +467,8 @@ def parse_unified_header(text): return None -def parse_context_header(text): - try: - lines = text.splitlines() - except AttributeError: - lines = text +def parse_context_header(text: str | list[str]) -> header | None: + lines = text.splitlines() if isinstance(text, str) else text headers = findall_regex(lines, context_header_old_line) if len(headers) == 0: @@ -526,11 +500,8 @@ def parse_context_header(text): return None -def parse_default_diff(text): - try: - lines = text.splitlines() - except AttributeError: - lines = text +def parse_default_diff(text: str | list[str]) -> list[Change] | None: + lines = text.splitlines() if isinstance(text, str) else text old = 0 new = 0 @@ -582,11 +553,8 @@ def parse_default_diff(text): return None -def parse_unified_diff(text): - try: - lines = text.splitlines() - except AttributeError: - lines = text +def parse_unified_diff(text: str | list[str]) -> list[Change] | None: + lines = text.splitlines() if isinstance(text, str) else text old = 0 new = 0 @@ -652,11 +620,8 @@ def parse_unified_diff(text): return None -def parse_context_diff(text): - try: - lines = text.splitlines() - except AttributeError: - lines = text +def parse_context_diff(text: str | list[str]) -> list[Change] | None: + lines = text.splitlines() if isinstance(text, str) else text old = 0 new = 0 @@ -795,11 +760,8 @@ def parse_context_diff(text): return None -def parse_ed_diff(text): - try: - lines = text.splitlines() - except AttributeError: - lines = text +def parse_ed_diff(text: str | list[str]) -> list[Change] | None: + lines = text.splitlines() if isinstance(text, str) else text old = 0 j = 0 @@ -878,12 +840,9 @@ def parse_ed_diff(text): return None -def parse_rcs_ed_diff(text): +def parse_rcs_ed_diff(text: str | list[str]) -> list[Change] | None: # much like forward ed, but no 'c' type - try: - lines = text.splitlines() - except AttributeError: - lines = text + lines = text.splitlines() if isinstance(text, str) else text old = 0 j = 0 @@ -905,7 +864,7 @@ def parse_rcs_ed_diff(text): hunk_kind = o.group(1) old = int(o.group(2)) - size = int(o.group(3)) + size = int(o.group(3)) if o.group(3) else 0 if hunk_kind == 'a': old += total_change_size + 1 @@ -926,15 +885,11 @@ def parse_rcs_ed_diff(text): if len(changes) > 0: return changes - return None -def parse_git_binary_diff(text): - try: - lines = text.splitlines() - except AttributeError: - lines = text +def parse_git_binary_diff(text: str | list[str]) -> list[Change] | None: + lines = text.splitlines() if isinstance(text, str) else text changes: list[Change] = list() diff --git a/openhands/resolver/patching/snippets.py b/openhands/resolver/patching/snippets.py index f9d9e620d0f7..1c90f573df15 100644 --- a/openhands/resolver/patching/snippets.py +++ b/openhands/resolver/patching/snippets.py @@ -1,10 +1,11 @@ # -*- coding: utf-8 -*- import os +import re from shutil import rmtree -def remove(path): +def remove(path: str) -> None: if os.path.exists(path): if os.path.isdir(path): rmtree(path) @@ -13,7 +14,7 @@ def remove(path): # find all indices of a list of strings that match a regex -def findall_regex(items, regex): +def findall_regex(items: list[str], regex: re.Pattern[str]) -> list[int]: found = list() for i in range(0, len(items)): k = regex.match(items[i]) @@ -24,7 +25,7 @@ def findall_regex(items, regex): return found -def split_by_regex(items, regex): +def split_by_regex(items: list[str], regex: re.Pattern[str]) -> list[list[str]]: splits = list() indices = findall_regex(items, regex) if not indices: @@ -45,8 +46,8 @@ def split_by_regex(items, regex): # http://stackoverflow.com/questions/377017/test-if-executable-exists-in-python -def which(program): - def is_exe(fpath): +def which(program: str) -> str | None: + def is_exe(fpath: str) -> bool: return os.path.isfile(fpath) and os.access(fpath, os.X_OK) fpath, fname = os.path.split(program) diff --git a/openhands/resolver/resolve_all_issues.py b/openhands/resolver/resolve_all_issues.py index 6aa32396545d..7696c06d2ad2 100644 --- a/openhands/resolver/resolve_all_issues.py +++ b/openhands/resolver/resolve_all_issues.py @@ -6,8 +6,9 @@ import os import pathlib import subprocess -from typing import Awaitable, TextIO +from typing import Any, Awaitable, TextIO +from pydantic import SecretStr from tqdm import tqdm import openhands @@ -25,7 +26,7 @@ ) -def cleanup(): +def cleanup() -> None: print('Cleaning up child processes...') for process in mp.active_children(): print(f'Terminating child process: {process.name}') @@ -214,7 +215,7 @@ async def resolve_issues( # Use asyncio.gather with a semaphore to limit concurrency sem = asyncio.Semaphore(num_workers) - async def run_with_semaphore(task): + async def run_with_semaphore(task: Awaitable[Any]) -> Any: async with sem: return await task @@ -228,7 +229,7 @@ async def run_with_semaphore(task): logger.info('Finished.') -def main(): +def main() -> None: parser = argparse.ArgumentParser( description='Resolve multiple issues from Github or Gitlab.' ) @@ -349,7 +350,7 @@ def main(): llm_config = LLMConfig( model=my_args.llm_model or os.environ['LLM_MODEL'], - api_key=str(api_key) if api_key else None, + api_key=SecretStr(api_key) if api_key else None, base_url=my_args.llm_base_url or os.environ.get('LLM_BASE_URL', None), ) diff --git a/openhands/resolver/resolve_issue.py b/openhands/resolver/resolve_issue.py index 80cddb9ed581..9a686815da38 100644 --- a/openhands/resolver/resolve_issue.py +++ b/openhands/resolver/resolve_issue.py @@ -10,6 +10,7 @@ from typing import Any from uuid import uuid4 +from pydantic import SecretStr from termcolor import colored import openhands @@ -18,6 +19,7 @@ from openhands.core.logger import openhands_logger as logger from openhands.core.main import create_runtime, run_controller from openhands.events.action import CmdRunAction, MessageAction +from openhands.events.event import Event from openhands.events.observation import ( CmdOutputObservation, ErrorObservation, @@ -48,7 +50,7 @@ def initialize_runtime( runtime: Runtime, platform: Platform, -): +) -> None: """Initialize the runtime for the agent. This function is called before the runtime is used to run the agent. @@ -192,26 +194,28 @@ async def process_issue( # This code looks unnecessary because these are default values in the config class # they're set by default if nothing else overrides them # FIXME we should remove them here - kwargs = {} + sandbox_config = SandboxConfig( + runtime_container_image=runtime_container_image, + enable_auto_lint=False, + use_host_network=False, + # large enough timeout, since some testcases take very long to run + timeout=300, + ) + if os.getenv('GITLAB_CI') == 'True': - kwargs['local_runtime_url'] = os.getenv('LOCAL_RUNTIME_URL', 'http://localhost') + sandbox_config.local_runtime_url = os.getenv( + 'LOCAL_RUNTIME_URL', 'http://localhost' + ) user_id = os.getuid() if hasattr(os, 'getuid') else 1000 if user_id == 0: - kwargs['user_id'] = get_unique_uid() + sandbox_config.user_id = get_unique_uid() config = AppConfig( default_agent='CodeActAgent', runtime='docker', max_budget_per_task=4, max_iterations=max_iterations, - sandbox=SandboxConfig( - runtime_container_image=runtime_container_image, - enable_auto_lint=False, - use_host_network=False, - # large enough timeout, since some testcases take very long to run - timeout=300, - **kwargs, - ), + sandbox=sandbox_config, # do not mount workspace workspace_base=workspace_base, workspace_mount_path=workspace_base, @@ -222,7 +226,7 @@ async def process_issue( runtime = create_runtime(config) await runtime.connect() - def on_event(evt): + def on_event(evt: Event) -> None: logger.info(evt) runtime.event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, str(uuid4())) @@ -524,10 +528,10 @@ async def resolve_issue( logger.info('Finished.') -def main(): +def main() -> None: import argparse - def int_or_none(value): + def int_or_none(value: str) -> int | None: if value.lower() == 'none': return None else: @@ -654,7 +658,7 @@ def int_or_none(value): api_key = my_args.llm_api_key or os.environ['LLM_API_KEY'] llm_config = LLMConfig( model=my_args.llm_model or os.environ['LLM_MODEL'], - api_key=str(api_key) if api_key else None, + api_key=SecretStr(api_key) if api_key else None, base_url=my_args.llm_base_url or os.environ.get('LLM_BASE_URL', None), ) diff --git a/openhands/resolver/send_pull_request.py b/openhands/resolver/send_pull_request.py index 7cbe37cfcca0..0ca032e4da97 100644 --- a/openhands/resolver/send_pull_request.py +++ b/openhands/resolver/send_pull_request.py @@ -5,6 +5,7 @@ import subprocess import jinja2 +from pydantic import SecretStr from openhands.core.config import LLMConfig from openhands.core.logger import openhands_logger as logger @@ -543,7 +544,7 @@ def process_all_successful_issues( ) -def main(): +def main() -> None: parser = argparse.ArgumentParser( description='Send a pull request to Github or Gitlab.' ) @@ -641,7 +642,7 @@ def main(): api_key = my_args.llm_api_key or os.environ['LLM_API_KEY'] llm_config = LLMConfig( model=my_args.llm_model or os.environ['LLM_MODEL'], - api_key=str(api_key) if api_key else None, + api_key=SecretStr(api_key) if api_key else None, base_url=my_args.llm_base_url or os.environ.get('LLM_BASE_URL', None), ) diff --git a/openhands/resolver/utils.py b/openhands/resolver/utils.py index b0e25861ccb7..d08686d592e3 100644 --- a/openhands/resolver/utils.py +++ b/openhands/resolver/utils.py @@ -107,7 +107,7 @@ def codeact_user_response( return msg -def cleanup(): +def cleanup() -> None: print('Cleaning up child processes...') for process in mp.active_children(): print(f'Terminating child process: {process.name}') @@ -115,7 +115,9 @@ def cleanup(): process.join() -def prepare_dataset(dataset: pd.DataFrame, output_file: str, eval_n_limit: int): +def prepare_dataset( + dataset: pd.DataFrame, output_file: str, eval_n_limit: int +) -> pd.DataFrame: assert 'instance_id' in dataset.columns, ( "Expected 'instance_id' column in the dataset. You should define your own " "unique identifier for each instance and use it as the 'instance_id' column." @@ -152,7 +154,7 @@ def prepare_dataset(dataset: pd.DataFrame, output_file: str, eval_n_limit: int): def reset_logger_for_multiprocessing( logger: logging.Logger, instance_id: str, log_dir: str -): +) -> None: """Reset the logger for multiprocessing. Save logs to a separate file for each process, instead of trying to write to the @@ -208,7 +210,7 @@ def extract_issue_references(body: str) -> list[int]: return [int(match) for match in re.findall(pattern, body)] -def get_unique_uid(start_uid=1000): +def get_unique_uid(start_uid: int = 1000) -> int: existing_uids = set() with open('/etc/passwd', 'r') as passwd_file: for line in passwd_file: diff --git a/openhands/resolver/visualize_resolver_output.py b/openhands/resolver/visualize_resolver_output.py index f7081f6c76f4..ca0717d84088 100644 --- a/openhands/resolver/visualize_resolver_output.py +++ b/openhands/resolver/visualize_resolver_output.py @@ -4,7 +4,9 @@ from openhands.resolver.io_utils import load_single_resolver_output -def visualize_resolver_output(issue_number: int, output_dir: str, vis_method: str): +def visualize_resolver_output( + issue_number: int, output_dir: str, vis_method: str +) -> None: output_jsonl = os.path.join(output_dir, 'output.jsonl') resolver_output = load_single_resolver_output(output_jsonl, issue_number) if vis_method == 'json':