Skip to content

Commit

Permalink
fix repo_slug parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
CTY-git committed Sep 17, 2024
1 parent 85321d7 commit 6a3f8d7
Show file tree
Hide file tree
Showing 9 changed files with 834 additions and 799 deletions.
30 changes: 15 additions & 15 deletions patchwork/common/client/llm/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from typing_extensions import Any, Dict, Iterable, List, Optional, Union

from patchwork.common.client.llm.protocol import NOT_GIVEN, LlmClient, NotGiven
from patchwork.common.client.llm.utils import base_model_to_schema, json_schema_to_model
from patchwork.common.client.llm.utils import json_schema_to_model


@functools.lru_cache
Expand Down Expand Up @@ -52,20 +52,20 @@ def is_model_supported(self, model: str) -> bool:
return model in self.get_models()

def chat_completion(
self,
messages: Iterable[ChatCompletionMessageParam],
model: str,
frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
n: Optional[int] | NotGiven = NOT_GIVEN,
presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
temperature: Optional[float] | NotGiven = NOT_GIVEN,
top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
top_p: Optional[float] | NotGiven = NOT_GIVEN,
self,
messages: Iterable[ChatCompletionMessageParam],
model: str,
frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
n: Optional[int] | NotGiven = NOT_GIVEN,
presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
temperature: Optional[float] | NotGiven = NOT_GIVEN,
top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
top_p: Optional[float] | NotGiven = NOT_GIVEN,
) -> ChatCompletion:
generation_dict = dict(
stop_sequences=[stop] if isinstance(stop, str) else stop,
Expand Down
100 changes: 47 additions & 53 deletions patchwork/common/client/scm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,17 @@
from github import Auth, Consts, Github, GithubException, PullRequest
from gitlab import Gitlab, GitlabAuthenticationError, GitlabError
from gitlab.v4.objects import ProjectMergeRequest
from giturlparse import parse, GitUrlParsed
from typing_extensions import Protocol, TypedDict

from patchwork.logger import logger


def get_slug_from_remote_url(remote_url: str) -> str:
# TODO: consider using https://github.com/nephila/giturlparse instead
if remote_url.startswith("git@"):
# ssh
_, _, potential_slug = remote_url.partition(":")
else:
potential_slug = "/".join(remote_url.split("/")[-2:])

if potential_slug.endswith(".git"):
potential_slug = potential_slug[:-4]

return potential_slug
parsed_repo: GitUrlParsed = parse(remote_url)
parts = [parsed_repo.owner, *parsed_repo.groups, parsed_repo.name]
slug = "/".join(parts)
return slug


@define
Expand Down Expand Up @@ -81,7 +75,7 @@ def set_pr_description(self, body: str) -> None:
...

def create_comment(
self, body: str, path: str | None = None, start_line: int | None = None, end_line: int | None = None
self, body: str, path: str | None = None, start_line: int | None = None, end_line: int | None = None
) -> str | None:
...

Expand Down Expand Up @@ -171,27 +165,27 @@ def find_pr_by_id(self, slug: str, pr_id: int) -> PullRequestProtocol | None:
...

def find_prs(
self,
slug: str,
state: PullRequestState | None = None,
original_branch: str | None = None,
feature_branch: str | None = None,
limit: int | None = None,
self,
slug: str,
state: PullRequestState | None = None,
original_branch: str | None = None,
feature_branch: str | None = None,
limit: int | None = None,
) -> list[PullRequestProtocol]:
...

def create_pr(
self,
slug: str,
title: str,
body: str,
original_branch: str,
feature_branch: str,
self,
slug: str,
title: str,
body: str,
original_branch: str,
feature_branch: str,
) -> PullRequestProtocol:
...

def create_issue_comment(
self, slug: str, issue_text: str, title: str | None = None, issue_id: int | None = None
self, slug: str, issue_text: str, title: str | None = None, issue_id: int | None = None
) -> str:
...

Expand All @@ -212,7 +206,7 @@ def set_pr_description(self, body: str) -> None:
self._mr.save()

def create_comment(
self, body: str, path: str | None = None, start_line: int | None = None, end_line: int | None = None
self, body: str, path: str | None = None, start_line: int | None = None, end_line: int | None = None
) -> str | None:
final_body = f"{_COMMENT_MARKER} \n{PullRequestProtocol._apply_pr_template(self, body)}"
if path is None:
Expand Down Expand Up @@ -316,7 +310,7 @@ def set_pr_description(self, body: str) -> None:
self._pr.edit(body=final_body)

def create_comment(
self, body: str, path: str | None = None, start_line: int | None = None, end_line: int | None = None
self, body: str, path: str | None = None, start_line: int | None = None, end_line: int | None = None
) -> str | None:
final_body = f"{_COMMENT_MARKER} \n{PullRequestProtocol._apply_pr_template(self, body)}"

Expand Down Expand Up @@ -416,12 +410,12 @@ def find_pr_by_id(self, slug: str, pr_id: int) -> PullRequestProtocol | None:
return None

def find_prs(
self,
slug: str,
state: PullRequestState | None = None,
original_branch: str | None = None,
feature_branch: str | None = None,
limit: int | None = None,
self,
slug: str,
state: PullRequestState | None = None,
original_branch: str | None = None,
feature_branch: str | None = None,
limit: int | None = None,
) -> list[GithubPullRequest]:
repo = self.github.get_repo(slug)
kwargs_list = dict(state=[None], target_branch=[None], source_branch=[None])
Expand Down Expand Up @@ -454,12 +448,12 @@ def find_prs(
return rv_list

def create_pr(
self,
slug: str,
title: str,
body: str,
original_branch: str,
feature_branch: str,
self,
slug: str,
title: str,
body: str,
original_branch: str,
feature_branch: str,
) -> PullRequestProtocol:
# before creating a PR, check if one already exists
repo = self.github.get_repo(slug)
Expand All @@ -468,7 +462,7 @@ def create_pr(
return pr

def create_issue_comment(
self, slug: str, issue_text: str, title: str | None = None, issue_id: int | None = None
self, slug: str, issue_text: str, title: str | None = None, issue_id: int | None = None
) -> str:
repo = self.github.get_repo(slug)
if issue_id is not None:
Expand Down Expand Up @@ -545,12 +539,12 @@ def find_pr_by_id(self, slug: str, pr_id: int) -> PullRequestProtocol | None:
return None

def find_prs(
self,
slug: str,
state: PullRequestState | None = None,
original_branch: str | None = None,
feature_branch: str | None = None,
limit: int | None = None,
self,
slug: str,
state: PullRequestState | None = None,
original_branch: str | None = None,
feature_branch: str | None = None,
limit: int | None = None,
) -> list[PullRequestProtocol]:
project = self.gitlab.projects.get(slug)
kwargs_list = dict(iterator=[True], state=[None], target_branch=[None], source_branch=[None])
Expand All @@ -576,12 +570,12 @@ def find_prs(
return rv_list

def create_pr(
self,
slug: str,
title: str,
body: str,
original_branch: str,
feature_branch: str,
self,
slug: str,
title: str,
body: str,
original_branch: str,
feature_branch: str,
) -> PullRequestProtocol:
# before creating a PR, check if one already exists
project = self.gitlab.projects.get(slug)
Expand All @@ -598,7 +592,7 @@ def create_pr(
return mr

def create_issue_comment(
self, slug: str, issue_text: str, title: str | None = None, issue_id: int | None = None
self, slug: str, issue_text: str, title: str | None = None, issue_id: int | None = None
) -> str:
if issue_id is not None:
obj = self.gitlab.projects.get(slug).issues.get(issue_id).notes.create({"body": issue_text})
Expand Down
4 changes: 3 additions & 1 deletion patchwork/patchflows/PRReview/PRReview.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ def run(self) -> dict:

header = ""
if self.verbosity > _SUMMARY_LEVEL[_SHORT]:
filtered_summaries = [str(summary["commit_message"]) for summary in summaries if summary.get("commit_message")]
filtered_summaries = [
str(summary["commit_message"]) for summary in summaries if summary.get("commit_message")
]
self.inputs["prompt_id"] = "diffreview_summary"
self.inputs["prompt_values"] = [{"diffreviews": "\n".join(filtered_summaries)}]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from pathlib import Path

from patchwork.step import Step
from patchwork.steps.GetTypescriptTypeInfo.typed import GetTypescriptTypeInfoInputs, GetTypescriptTypeInfoOutputs

from patchwork.steps.GetTypescriptTypeInfo.typed import (
GetTypescriptTypeInfoInputs,
GetTypescriptTypeInfoOutputs,
)

_DEFAULT_TS_FILE = Path(__file__).parent / "get_type_info.ts"

Expand Down
4 changes: 3 additions & 1 deletion patchwork/steps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
GenerateCodeRepositoryEmbeddings,
)
from patchwork.steps.GenerateEmbeddings.GenerateEmbeddings import GenerateEmbeddings
from patchwork.steps.GetTypescriptTypeInfo.GetTypescriptTypeInfo import (
GetTypescriptTypeInfo,
)
from patchwork.steps.JoinList.JoinList import JoinList
from patchwork.steps.JoinListPB.JoinListPB import JoinListPB
from patchwork.steps.LLM.LLM import LLM
Expand All @@ -46,7 +49,6 @@
from patchwork.steps.SimplifiedLLMOnce.SimplifiedLLMOnce import SimplifiedLLMOnce
from patchwork.steps.SimplifiedLLMOncePB.SimplifiedLLMOncePB import SimplifiedLLMOncePB
from patchwork.steps.SlackMessage.SlackMessage import SlackMessage
from patchwork.steps.GetTypescriptTypeInfo.GetTypescriptTypeInfo import GetTypescriptTypeInfo

__all__ = [
"AnalyzeImpact",
Expand Down
Loading

0 comments on commit 6a3f8d7

Please sign in to comment.